Skip to content

Commit 284c95c

Browse files
committed
Add option to plot ratios, add option to choose lin or log as scale for A
1 parent fc68180 commit 284c95c

1 file changed

Lines changed: 105 additions & 148 deletions

File tree

pdfplotter/pdf_set_nuclear.py

Lines changed: 105 additions & 148 deletions
Original file line numberDiff line numberDiff line change
@@ -177,9 +177,9 @@ def plot_A_dep_3d(
177177
Q2: float | None = None,
178178
x_lines: float | list[float] | None = None,
179179
colors: list = [],
180-
logA: bool = True,
180+
A_scale: Literal["log", "linlog", "lin"] = "log",
181181
plot_uncertainty: bool = True,
182-
plot_ratio: bool = False,
182+
ratio_to: PDFSet | None = None,
183183
pdf_label: Literal["ylabel", "annotate"] = "annotate",
184184
A_label: Literal["legend", "ticks", "both"] = "ticks",
185185
proj_type: Literal["ortho", "persp"] = "persp",
@@ -216,7 +216,7 @@ def plot_A_dep_3d(
216216
If True, plot the ratio of the PDFs to the Proton PDF, by default False
217217
pdf_label : str, optional
218218
The label for the PDF, by default "annotate". If "ylabel", it is used in ax.set_title(). If "annotate", the label is set as an annotation.
219-
A_label:
219+
A_label:
220220
If "ticks", the values for A are chosen as z-ticks. If "legend", a legend is plottet. if "both" both is realised
221221
kwargs_theory : dict[str, Any] | list[dict[str, Any] | None], optional
222222
The keyword arguments to pass to the plot function for the central PDF, by default {}.
@@ -229,7 +229,7 @@ def plot_A_dep_3d(
229229
kwargs_xlabel : dict[str, Any] | list[dict[str, Any] | None], optional
230230
The keyword arguments to pass to the xlabel function, by default {}.
231231
kwargs_ylabel : dict[str, Any] | list[dict[str, Any] | None], optional
232-
The keyword arguments to pass to the zlabel function, the A-axis, by default {}.
232+
The keyword arguments to pass to the zlabel function, the A-axis, by default {}.
233233
kwargs_zlabel : dict[str, Any] | list[dict[str, Any] | None], optional
234234
The keyword arguments to pass to the ylabel function, the f(x,Q)-axis,, by default {}.
235235
kwargs_title : dict[str, Any], optional
@@ -268,9 +268,6 @@ def plot_A_dep_3d(
268268
if not isinstance(A, list):
269269
A = [A]
270270

271-
if 1 not in A and plot_ratio:
272-
raise ValueError("Please pass A=1 if you want to plot the ratio to Proton.")
273-
274271
if isinstance(observables, np.ndarray):
275272
observables = list(observables.flatten())
276273

@@ -287,20 +284,13 @@ def plot_A_dep_3d(
287284
for i, (obs_i, ax_i) in enumerate(zip(observables, ax.flat)):
288285

289286
for j, (A_j, col_j) in enumerate(zip(A, colors)):
290-
if not plot_ratio:
291-
z_lower, z_upper = self.get(A=A_j).get_uncertainties(
292-
observable=obs_i, x=x, Q=Q, Q2=Q2
293-
)
294-
else:
295-
z_lower, z_upper = self.get(A=A_j).get_uncertainties(
296-
observable=obs_i, x=x, Q=Q, Q2=Q2
297-
)
298-
z_lower = z_lower / self.get(A=1).get_central(
299-
observable=obs_i, x=x, Q=Q, Q2=Q2
300-
)
301-
z_upper = z_upper / self.get(A=1).get_central(
302-
observable=obs_i, x=x, Q=Q, Q2=Q2
303-
)
287+
z_upper = self.get(A=A_j).get_uncertainties(
288+
observable=obs_i, x=x, Q=Q, Q2=Q2, ratio_to=ratio_to
289+
)[0]
290+
z_lower = [k if k>0 else 0 for k in self.get(A=A_j).get_uncertainties(
291+
observable=obs_i, x=x, Q=Q, Q2=Q2, ratio_to=ratio_to
292+
)[1]]
293+
304294
kwargs_default = {
305295
"color": col_j,
306296
"label": f"A={A_j}",
@@ -317,50 +307,17 @@ def plot_A_dep_3d(
317307
kwargs_theory,
318308
i=j,
319309
)
320-
if logA:
321-
if plot_ratio:
322-
ax_i.plot(
323-
np.log10(x),
324-
np.log10(len(x) * [A_j]),
325-
self.get(A=A_j).get_central(
326-
x=x, Q=Q, Q2=Q2, observable=obs_i
327-
)
328-
/ self.get(A=1).get_central(
329-
x=x, Q=Q, Q2=Q2, observable=obs_i
330-
),
331-
**kwargs,
332-
)
333-
else:
334-
ax_i.plot(
335-
np.log10(x),
336-
np.log10(len(x) * [A_j]),
337-
self.get(A=A_j).get_central(
338-
x=x, Q=Q, Q2=Q2, observable=obs_i
339-
),
340-
**kwargs,
341-
)
342-
else:
343-
if plot_ratio:
344-
ax_i.plot(
345-
np.log10(x),
346-
len(x) * [A_j],
347-
self.get(A=A_j).get_central(
348-
x=x, Q=Q, Q2=Q2, observable=obs_i
349-
)
350-
/ self.get(A=1).get_central(
351-
x=x, Q=Q, Q2=Q2, observable=obs_i
352-
),
353-
**kwargs,
354-
)
355-
else:
356-
ax_i.plot(
357-
np.log10(x),
358-
len(x) * [A_j],
359-
self.get(A=A_j).get_central(
360-
x=x, Q=Q, Q2=Q2, observable=obs_i
361-
),
362-
**kwargs,
363-
)
310+
if A_scale == "log":
311+
Aj_arr = np.log10(len(x) * [A_j])
312+
if A_scale == "lin":
313+
Aj_arr = len(x) * [A_j]
314+
ax_i.plot(
315+
np.log10(x),
316+
Aj_arr,
317+
[k if k>0 else 0 for k in self.get(A=A_j).get_central(x=x, Q=Q, Q2=Q2, observable=obs_i,ratio_to=ratio_to)],
318+
**kwargs,
319+
)
320+
364321
if plot_uncertainty:
365322
kwargs_uncertainty_default = {
366323
"color": col_j,
@@ -381,41 +338,26 @@ def plot_A_dep_3d(
381338
vertices = []
382339
z_lower = np.array(z_lower)
383340
z_upper = np.array(z_upper)
384-
if not logA:
385-
386-
for xi, ai, zl, zu in zip(
387-
np.log10(x), np.ones(len(x)) * A_j, z_lower, z_upper
388-
):
389-
vertices.append([xi, ai, zl])
390-
for xi, ai, zl, zu in reversed(
391-
list(
392-
zip(
393-
np.log10(x), np.ones(len(x)) * A_j, z_lower, z_upper
394-
)
395-
)
396-
):
397-
vertices.append([xi, ai, zu])
398341

399-
else:
400-
for xi, ai, zl, zu in zip(
401-
np.log10(x),
402-
np.ones(len(x)) * np.log10(A_j),
403-
z_lower,
404-
z_upper,
405-
):
406-
vertices.append([xi, ai, zl])
407-
408-
for xi, ai, zl, zu in reversed(
409-
list(
410-
zip(
411-
np.log10(x),
412-
np.ones(len(x)) * np.log10(A_j),
413-
z_lower,
414-
z_upper,
415-
)
342+
for xi, ai, zl, zu in zip(
343+
np.log10(x),
344+
Aj_arr,
345+
z_lower,
346+
z_upper,
347+
):
348+
vertices.append([xi, ai, zl])
349+
350+
for xi, ai, zl, zu in reversed(
351+
list(
352+
zip(
353+
np.log10(x),
354+
Aj_arr,
355+
z_lower,
356+
z_upper,
416357
)
417-
):
418-
vertices.append([xi, ai, zu])
358+
)
359+
):
360+
vertices.append([xi, ai, zu])
419361
poly = Poly3DCollection([vertices], **kwargs)
420362
ax_i.add_collection3d(poly)
421363

@@ -434,16 +376,9 @@ def plot_A_dep_3d(
434376
kwargs_uncertainty_edges,
435377
i=j,
436378
)
437-
if not logA:
438-
ax_i.plot(np.log10(x), len(x) * [A_j], z_upper, **kwargs)
439-
ax_i.plot(np.log10(x), len(x) * [A_j], z_lower, **kwargs)
440-
else:
441-
ax_i.plot(
442-
np.log10(x), len(x) * [np.log10(A_j)], z_upper, **kwargs
443-
)
444-
ax_i.plot(
445-
np.log10(x), len(x) * [np.log10(A_j)], z_lower, **kwargs
446-
)
379+
380+
ax_i.plot(np.log10(x), Aj_arr, z_upper, **kwargs)
381+
ax_i.plot(np.log10(x), Aj_arr, z_lower, **kwargs)
447382

448383
centrals = {}
449384
if x_lines is not None:
@@ -471,53 +406,63 @@ def plot_A_dep_3d(
471406
i=k,
472407
)
473408
for a in A:
474-
if not plot_ratio:
475-
if x_line not in centrals.keys():
476-
centrals[x_line] = [
477-
self.get(A=a).get_central(
478-
x=x_line, Q=Q, Q2=Q2, observable=obs_i
479-
)
480-
]
481-
else:
482-
centrals[x_line].append(
483-
self.get(A=a).get_central(
484-
x=x_line, Q=Q, Q2=Q2, observable=obs_i
485-
)
409+
410+
if x_line not in centrals.keys():
411+
centrals[x_line] = [
412+
self.get(A=a).get_central(
413+
x=x_line,
414+
Q=Q,
415+
Q2=Q2,
416+
observable=obs_i,
417+
ratio_to=ratio_to,
486418
)
419+
]
487420
else:
488-
if x_line not in centrals.keys():
489-
centrals[x_line] = [
490-
self.get(A=a).get_central(
491-
x=x_line, Q=Q, Q2=Q2, observable=obs_i
492-
)
493-
/ self.get(A=1).get_central(
494-
x=x_line, Q=Q, Q2=Q2, observable=obs_i
495-
)
496-
]
497-
else:
498-
centrals[x_line].append(
499-
self.get(A=a).get_central(
500-
x=x_line, Q=Q, Q2=Q2, observable=obs_i
501-
)
502-
/ self.get(A=1).get_central(
503-
x=x_line, Q=Q, Q2=Q2, observable=obs_i
504-
)
421+
centrals[x_line].append(
422+
self.get(A=a).get_central(
423+
x=x_line,
424+
Q=Q,
425+
Q2=Q2,
426+
observable=obs_i,
427+
ratio_to=ratio_to,
505428
)
506-
if logA:
429+
)
430+
if A_scale == "log":
507431
ax_i.plot(
508432
np.ones(len(A)) * np.log10(x_line),
509433
np.log10(A),
510434
centrals[x_line],
511435
**kwargs,
512436
)
513-
else:
437+
ax_i.plot(
438+
[np.log10(x_line),np.log10(x_line)],
439+
[np.log10(A[0]),np.log10(A[0])],
440+
[0,self.get(A=A[0]).get_central(
441+
x=x_line,
442+
Q=Q,
443+
Q2=Q2,
444+
observable=obs_i,
445+
ratio_to=ratio_to,
446+
)],
447+
**kwargs,)
448+
elif A_scale == "lin":
514449
ax_i.plot(
515450
np.ones(len(A)) * np.log10(x_line),
516451
A,
517452
centrals[x_line],
518453
**kwargs,
519454
)
520-
455+
ax_i.plot(
456+
[np.log10(x_line),np.log10(x_line)],
457+
[A[0],A[0]],
458+
[0,self.get(A=A[0]).get_central(
459+
x=x_line,
460+
Q=Q,
461+
Q2=Q2,
462+
observable=obs_i,
463+
ratio_to=ratio_to,
464+
)],
465+
**kwargs,)
521466
if pdf_label == "annotate":
522467
kwargs_annotation_default = {
523468
"fontsize": 12,
@@ -546,10 +491,16 @@ def plot_A_dep_3d(
546491
ax_i.annotate(f"${util.to_str(obs_i, Q=Q,Q2=Q2)}$", **kwargs_n)
547492

548493
if pdf_label == "ylabel":
549-
kwargs_ylabel_default = {
550-
"fontsize": 14,
551-
"zlabel": f"${util.to_str(obs_i,Q=Q,Q2=Q2)}$",
552-
}
494+
if ratio_to:
495+
kwargs_ylabel_default = {
496+
"fontsize": 14,
497+
"zlabel": f"${util.to_str(obs_i,Q=Q,Q2=Q2,R=True)}$",
498+
}
499+
else:
500+
kwargs_ylabel_default = {
501+
"fontsize": 14,
502+
"zlabel": f"${util.to_str(obs_i,Q=Q,Q2=Q2,R=False)}$",
503+
}
553504
if not isinstance(kwargs_ylabel, list):
554505
kwargs = update_kwargs(
555506
kwargs_ylabel_default,
@@ -577,14 +528,16 @@ def plot_A_dep_3d(
577528
boxstyle="round, pad=0.2",
578529
),
579530
}
580-
kwargs_n = update_kwargs(kwargs_annotation_default, kwargs_annotation, i=i)
531+
kwargs_n = update_kwargs(
532+
kwargs_annotation_default, kwargs_annotation, i=i
533+
)
581534

582535
ax_i.annotate(f"${util.to_str(obs_i, Q=Q,Q2=Q2)}$", **kwargs_n)
583536
ax_i.xaxis.set_major_formatter(mticker.FuncFormatter(log_tick_formatter))
584537
if A_label == "ticks" or A_label == "both":
585-
if logA:
538+
if A_scale == "log":
586539
ax_i.set_yticks(np.log10(A), A)
587-
else:
540+
elif A_scale == "lin":
588541
ax_i.set_yticks(A, A)
589542
kwargs_zlabel_default = {
590543
"fontsize": 14,
@@ -642,7 +595,11 @@ def plot_A_dep_3d(
642595
)
643596
ax_i.set_xlabel(**kwargs)
644597
# , np.log10(x[-1]))
645-
ax_i.set_zlim(ax_i.get_zlim()[1] * 0.02)
598+
if not ratio_to:
599+
ax_i.set_zlim(ax_i.get_zlim()[1] * 0.02)
600+
else:
601+
ax_i.set_zlim(ax_i.get_zlim()[0], 2*np.median(self.get(A=A_j).get_central(x=x, Q=Q, Q2=Q2, observable=obs_i,ratio_to=ratio_to)))
602+
ax_i.set_zlim(ax_i.get_zlim()[1] * 0.02)
646603
# ax_i.yaxis._axinfo["grid"]["linewidth"] = 0
647604
ax_i.set_proj_type(proj_type)
648605
ax_i.view_init(*view_init[i] if isinstance(view_init, list) else view_init)

0 commit comments

Comments
 (0)