@@ -23,6 +23,7 @@ class SolverCG : public Solver<howmany, n_str> {
2323 RealArray s_real;
2424 RealArray d_real;
2525 RealArray rnew_real;
26+ double alpha_warm{0.1 };
2627
2728 void internalSolve ();
2829 void LineSearchSecant ();
@@ -64,6 +65,7 @@ void SolverCG<howmany, n_str>::internalSolve()
6465 printf (" \n # Start FANS - Conjugate Gradient Solver \n " );
6566
6667 bool islinear = this ->matmanager ->all_linear ;
68+ alpha_warm = 0.1 ;
6769
6870 s_real.setZero ();
6971 d_real.setZero ();
@@ -117,33 +119,44 @@ void SolverCG<howmany, n_str>::internalSolve()
117119template <int howmany, int n_str>
118120void SolverCG<howmany, n_str>::LineSearchSecant()
119121{
120- double err = 10.0 ;
121- int MaxIter = 5 ;
122- double tol = 1e-2 ;
123- int _iter = 0 ;
124- double alpha_new = 0.0001 ;
125- double alpha_old = 0 ;
126-
127- double r1pd;
128- double rpd = dotProduct (v_r_real, d_real);
122+ double err = 10.0 ;
123+ const int MaxIter = this ->reader .ls_max_iter ;
124+ const double tol = this ->reader .ls_tol ;
125+ int _iter = 0 ;
126+ double alpha_prev = 0.0 ;
127+ double alpha_curr = alpha_warm;
129128
130- while (((_iter < MaxIter) && (err > tol))) {
129+ double rpd = dotProduct (v_r_real, d_real);
130+ v_u_real += d_real * alpha_curr;
131+ this ->updateMixedBC ();
132+ this ->template compute_residual <0 >(rnew_real, v_u_real);
133+ double r1pd = dotProduct (rnew_real, d_real);
134+
135+ double denom, alpha_next;
136+ while (_iter < MaxIter && err > tol) {
137+ denom = r1pd - rpd;
138+ if (fabs (denom) < 1e-14 * (fabs (r1pd) + fabs (rpd)))
139+ break ;
140+
141+ alpha_next = alpha_curr - r1pd * (alpha_curr - alpha_prev) / denom;
142+ if (alpha_next <= 0.0 )
143+ alpha_next = 0.5 * (alpha_prev + alpha_curr);
144+ err = fabs (alpha_next - alpha_curr);
145+
146+ v_u_real += d_real * (alpha_next - alpha_curr);
147+ alpha_prev = alpha_curr;
148+ rpd = r1pd;
149+ alpha_curr = alpha_next;
150+ _iter++;
131151
132- v_u_real += d_real * (alpha_new - alpha_old);
133152 this ->updateMixedBC ();
134153 this ->template compute_residual <0 >(rnew_real, v_u_real);
135154 r1pd = dotProduct (rnew_real, d_real);
136-
137- alpha_old = alpha_new;
138- alpha_new *= rpd / (rpd - r1pd);
139-
140- err = fabs (alpha_new - alpha_old);
141- _iter++;
142155 }
143- v_u_real += d_real * (alpha_new - alpha_old) ;
144- v_r_real = rnew_real;
156+ alpha_warm = (_iter == MaxIter && err > tol) ? 0.1 : alpha_curr ;
157+ v_r_real = rnew_real;
145158 if (this ->world_rank == 0 )
146- printf (" line search iter %i, alpha %f - error %e - " , _iter, alpha_new , err);
159+ printf (" line search iter %i, alpha %f - error %e - " , _iter, alpha_curr , err);
147160}
148161
149162template <int howmany, int n_str>
0 commit comments