forked from stanfordnlp/dspy
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrefine.py
More file actions
203 lines (170 loc) · 9.35 KB
/
refine.py
File metadata and controls
203 lines (170 loc) · 9.35 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
import inspect
import textwrap
from typing import Callable
import ujson
import dspy
from dspy.adapters.utils import get_field_description_string
from dspy.predict.predict import Prediction
from dspy.signatures import InputField, OutputField, Signature
from .predict import Module
# NOTE: This is a dspy Signature, so the docstring below is NOT mere
# documentation -- it becomes the instructions sent to the LM, and each
# field's `desc` string becomes part of the prompt. Do not edit them casually.
class OfferFeedback(Signature):
    """
    In the discussion, assign blame to each module that contributed to the final reward being below the threshold, if
    any. Then, prescribe concrete advice of how the module should act on its future input when we retry the process, if
    it were to receive the same or similar inputs. If a module is not to blame, the advice should be N/A.
    The module will not see its own history, so it needs to rely on entirely concrete and actionable advice from you
    to avoid the same mistake on the same or similar inputs.
    """

    # Inputs: context about the failed run, pre-serialized to strings by Refine.forward.
    program_code: str = InputField(desc="The code of the program that we are analyzing")
    modules_defn: str = InputField(desc="The definition of each module in the program, including its I/O")
    program_inputs: str = InputField(desc="The inputs to the program that we are analyzing")
    program_trajectory: str = InputField(desc="The trajectory of the program's execution, showing each module's I/O")
    program_outputs: str = InputField(desc="The outputs of the program that we are analyzing")
    reward_code: str = InputField(desc="The code of the reward function that we are analyzing")
    target_threshold: float = InputField(desc="The target threshold for the reward function")
    reward_value: float = InputField(desc="The reward value assigned to the program's outputs")
    module_names: list[str] = InputField(desc="The names of the modules in the program, for which we seek advice")
    # Outputs: free-form blame discussion plus per-module advice keyed by module name.
    discussion: str = OutputField(desc="Discussing blame of where each module went wrong, if it did")
    advice: dict[str, str] = OutputField(
        desc="For each module, describe very concretely, in this order: the specific scenarios in which it has made "
        "mistakes in the past and what each mistake was, followed by what it should do differently in that kind of"
        "scenario in the future. If the module is not to blame, write N/A."
    )
class Refine(Module):
    """
    Refines a module by running it up to `N` times with different temperatures and returns the best
    prediction, as defined by the reward_fn, or the first prediction that passes the threshold. After
    each attempt (except the final one), `Refine` automatically generates detailed feedback about the
    module's performance and uses this feedback as hints for subsequent runs, creating an iterative
    refinement process.

    Example:
    ```python
    import dspy

    # Use a chain-of-thought QA module as the base
    qa = dspy.ChainOfThought("question -> answer")

    # Define a reward function that checks for one-word answers
    def one_word_answer(args, pred):
        return 1.0 if len(pred.answer.split()) == 1 else 0.0

    # Create the refined module
    best_of_3 = dspy.Refine(module=qa, N=3, reward_fn=one_word_answer, threshold=1.0)

    # Use the refined module
    result = best_of_3(question="What is the capital of Belgium?").answer
    # Returns: Brussels
    ```

    By default, `Refine` will try to run the base module up to N times until the threshold is met.
    If the module encounters an error, it will keep going up to N failed attempts. You can adjust
    this behavior with the `fail_count` argument to control the number of computation attempts
    allowed before raising an error.
    """

    def __init__(
        self,
        module: Module,
        N: int,  # noqa: N803
        reward_fn: Callable[[dict, Prediction], float],
        threshold: float,
        fail_count: int | None = None,
    ):
        """
        Args:
            module: The dspy module whose calls are retried and refined.
            N: Maximum number of attempts, each run with a different temperature.
            reward_fn: Maps (program kwargs, prediction) to a scalar reward.
            threshold: Stop early as soon as a prediction's reward reaches this value.
            fail_count: Number of failed attempts tolerated before the last error is
                re-raised. Defaults to `N`.
        """
        super().__init__()  # FIX: initialize the base Module state
        self.module = module
        # Wrapped in a lambda so the reward function is not registered as a dspy parameter.
        self.reward_fn = lambda *args: reward_fn(*args)
        self.threshold = threshold
        self.N = N
        self.fail_count = fail_count or N  # default to N if fail_count is not provided
        self.module_code = inspect.getsource(module.__class__)
        try:
            self.reward_fn_code = inspect.getsource(reward_fn)
        except TypeError:
            # reward_fn is a callable object rather than a plain function.
            self.reward_fn_code = inspect.getsource(reward_fn.__class__)

    def forward(self, **kwargs):
        """Run the wrapped module up to `N` times, feeding advice from failed
        attempts back in as hints, and return the highest-reward prediction."""
        lm = self.module.get_lm() or dspy.settings.lm
        # The first attempt keeps the configured temperature; subsequent attempts
        # sweep temperatures in [0.5, 1.0) to diversify outputs. Duplicates are
        # dropped (order-preserving) and the list is capped at N attempts.
        temps = [lm.kwargs["temperature"]] + [0.5 + i * (0.5 / self.N) for i in range(self.N)]
        temps = list(dict.fromkeys(temps))[: self.N]
        best_pred, best_trace, best_reward = None, None, -float("inf")
        advice = None
        adapter = dspy.settings.adapter or dspy.ChatAdapter()
        # FIX: keep the failure budget in a local variable. The original mutated
        # self.fail_count, which permanently depleted the budget across separate
        # forward() calls, and compared the attempt index (not the number of
        # failures) against the budget.
        remaining_failures = self.fail_count

        for idx, t in enumerate(temps):
            lm_ = lm.copy(temperature=t)
            mod = self.module.deepcopy()
            mod.set_lm(lm_)

            predictor2name = {predictor: name for name, predictor in mod.named_predictors()}
            signature2name = {predictor.signature: name for name, predictor in mod.named_predictors()}
            module_names = [name for name, _ in mod.named_predictors()]

            try:
                with dspy.context(trace=[]):
                    if not advice:
                        outputs = mod(**kwargs)
                    else:
                        # Inject the per-module advice as an extra `hint_` input
                        # field without permanently modifying the signatures.
                        class WrapperAdapter(adapter.__class__):
                            def __call__(self, lm, lm_kwargs, signature, demos, inputs):
                                inputs["hint_"] = advice.get(signature2name[signature], "N/A")  # noqa: B023
                                signature = signature.append(
                                    "hint_", InputField(desc="A hint to the module from an earlier run")
                                )
                                return adapter(lm, lm_kwargs, signature, demos, inputs)

                        with dspy.context(adapter=WrapperAdapter()):
                            outputs = mod(**kwargs)

                    trace = dspy.settings.trace.copy()

                    # TODO: Remove the hint from the trace, if it's there.

                    # NOTE: Not including the trace of reward_fn.
                    reward = self.reward_fn(kwargs, outputs)

                if reward > best_reward:
                    best_reward, best_pred, best_trace = reward, outputs, trace

                if self.threshold is not None and reward >= self.threshold:
                    break

                if idx == self.N - 1:
                    break

                # Below threshold: ask an LM to critique this attempt and produce
                # per-module advice, used as hints on the next attempt.
                modules = {"program_code": self.module_code, "modules_defn": inspect_modules(mod)}
                trajectory = [{"module_name": predictor2name[p], "inputs": i, "outputs": dict(o)} for p, i, o in trace]
                trajectory = {
                    "program_inputs": kwargs,
                    "program_trajectory": trajectory,
                    "program_outputs": dict(outputs),
                }
                reward = {
                    "reward_code": self.reward_fn_code,
                    "target_threshold": self.threshold,
                    "reward_value": reward,
                }

                advise_kwargs = dict(**modules, **trajectory, **reward, module_names=module_names)
                # Serialize non-string values, masking anything ujson cannot handle.
                advise_kwargs = {
                    k: v if isinstance(v, str) else ujson.dumps(recursive_mask(v), indent=2)
                    for k, v in advise_kwargs.items()
                }

                advice = dspy.Predict(OfferFeedback)(**advise_kwargs).advice

            except Exception as e:
                print(f"Refine: Attempt failed with temperature {t}: {e}")
                remaining_failures -= 1
                if remaining_failures < 0:
                    raise e

        if best_trace:
            dspy.settings.trace.extend(best_trace)
        return best_pred
def inspect_modules(program):
    """Render a human-readable report of every predictor in `program`.

    For each (name, predictor) pair from `program.named_predictors()`, emits the
    module name, its input/output field descriptions, and its (dedented) signature
    instructions, with sections separated by a dashed rule.

    Args:
        program: A dspy module exposing `named_predictors()`.

    Returns:
        The full report as a single newline-joined string.
    """
    separator = "-" * 80
    output = [separator]

    # FIX: iterate directly -- the original wrapped this in enumerate() and
    # discarded the index.
    for name, predictor in program.named_predictors():
        signature = predictor.signature
        instructions = textwrap.dedent(signature.instructions)
        # Indent every instruction line by two tabs for readability.
        instructions = ("\n" + "\t" * 2).join([""] + instructions.splitlines())

        output.append(f"Module {name}")
        output.append("\n\tInput Fields:")
        output.append(("\n" + "\t" * 2).join([""] + get_field_description_string(signature.input_fields).splitlines()))
        output.append("\tOutput Fields:")
        output.append(("\n" + "\t" * 2).join([""] + get_field_description_string(signature.output_fields).splitlines()))
        output.append(f"\tOriginal Instructions: {instructions}")
        output.append(separator)

    return "\n".join([o.strip("\n") for o in output])
def recursive_mask(o):
    """Return a copy of `o` safe for ujson serialization.

    Serializable values are returned unchanged. Dicts, lists, and tuples are
    rebuilt with their values masked recursively; any other non-serializable
    leaf is replaced by a `<non-serializable: TYPE>` placeholder string.
    """
    # Fast path: anything ujson already accepts is passed through untouched.
    try:
        ujson.dumps(o)
    except TypeError:
        pass
    else:
        return o

    # Containers: rebuild the same container type with masked values.
    if isinstance(o, dict):
        return {key: recursive_mask(value) for key, value in o.items()}
    if isinstance(o, list):
        return [recursive_mask(item) for item in o]
    if isinstance(o, tuple):
        return tuple(recursive_mask(item) for item in o)

    # Opaque leaf: substitute a descriptive placeholder string.
    return f"<non-serializable: {type(o).__name__}>"