1 // Copyright 2010-2021 Google LLC
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 #ifndef OR_TOOLS_SAT_PSEUDO_COSTS_H_
15 #define OR_TOOLS_SAT_PSEUDO_COSTS_H_
16 
17 #include <vector>
18 
19 #include "ortools/base/strong_vector.h"
20 #include "ortools/sat/integer.h"
21 #include "ortools/sat/util.h"
22 
23 namespace operations_research {
24 namespace sat {
25 
26 // Pseudo cost of a variable is measured as average observed change in the
27 // objective bounds per unit change in the variable bounds.
28 class PseudoCosts {
29  public:
30   // Helper struct to get information relavant for pseudo costs from branching
31   // decisions.
32   struct VariableBoundChange {
33     IntegerVariable var = kNoIntegerVariable;
34     IntegerValue lower_bound_change = IntegerValue(0);
35   };
36   explicit PseudoCosts(Model* model);
37 
38   // Updates the pseudo costs for the given decision.
39   void UpdateCost(const std::vector<VariableBoundChange>& bound_changes,
40                   IntegerValue obj_bound_improvement);
41 
42   // Returns the variable with best reliable pseudo cost that is not fixed.
43   IntegerVariable GetBestDecisionVar();
44 
45   // Returns the pseudo cost of given variable. Currently used for testing only.
46   double GetCost(IntegerVariable var) const {
47     CHECK_LT(var, pseudo_costs_.size());
48     return pseudo_costs_[var].CurrentAverage();
49   }
50 
51   // Returns the number of recordings of given variable. Currently used for
52   // testing only.
53   int GetRecordings(IntegerVariable var) const {
54     CHECK_LT(var, pseudo_costs_.size());
55     return pseudo_costs_[var].NumRecords();
56   }
57 
58  private:
59   // Updates the cost of a given variable.
60   void UpdateCostForVar(IntegerVariable var, double new_cost);
61 
62   // Reference of integer trail to access the current bounds of variables.
63   const IntegerTrail& integer_trail_;
64 
65   const SatParameters& parameters_;
66 
67   absl::StrongVector<IntegerVariable, IncrementalAverage> pseudo_costs_;
68 };
69 
70 // Returns extracted information to update pseudo costs from the given
71 // branching decision.
72 std::vector<PseudoCosts::VariableBoundChange> GetBoundChanges(
73     LiteralIndex decision, Model* model);
74 
75 }  // namespace sat
76 }  // namespace operations_research
77 
78 #endif  // OR_TOOLS_SAT_PSEUDO_COSTS_H_
79