Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@ class NumberNode : public ArrayOutputMixin<ArrayNode>, public DecisionNode {
double upper_bound(ssize_t index) const;
double upper_bound() const;

/// Returns true IFF the index-wise bounds are uniform.
bool uniform_index_wise_bounds() const { return uniform_index_wise_bounds_; };

// Clip value in a given state to fall within upper_bound and lower_bound
// in a given index.
bool clip_and_set_value(State& state, ssize_t index, double value) const;
Expand All @@ -123,6 +126,9 @@ class NumberNode : public ArrayOutputMixin<ArrayNode>, public DecisionNode {

std::vector<double> lower_bounds_;
std::vector<double> upper_bounds_;

// Indicator variable that the index-wise bounds are uniform.
const bool uniform_index_wise_bounds_;
};

/// A contiguous block of integer numbers.
Expand Down
3 changes: 2 additions & 1 deletion dwave/optimization/src/nodes/numbers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,8 @@ NumberNode::NumberNode(std::span<const ssize_t> shape, std::vector<double> lower
min_(get_extreme_index_wise_bound<false>(lower_bound)),
max_(get_extreme_index_wise_bound<true>(upper_bound)),
lower_bounds_(std::move(lower_bound)),
upper_bounds_(std::move(upper_bound)) {
upper_bounds_(std::move(upper_bound)),
uniform_index_wise_bounds_(lower_bounds_.size() == 1 && upper_bounds_.size() == 1) {
if ((shape.size() > 0) && (shape[0] < 0)) {
throw std::invalid_argument("Number array cannot have dynamic size.");
}
Expand Down
15 changes: 15 additions & 0 deletions tests/cpp/nodes/test_numbers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ TEST_CASE("BinaryNode") {
CHECK(bnode_ptr->upper_bound() == 1.0);
CHECK(bnode_ptr->lower_bound(0) == 0.0);
CHECK(bnode_ptr->upper_bound(1) == 1.0);
CHECK(bnode_ptr->uniform_index_wise_bounds());
}
}

Expand All @@ -62,6 +63,7 @@ TEST_CASE("BinaryNode") {
CHECK(ptr->lower_bound(i) == 0);
CHECK(ptr->upper_bound(i) == 1);
}
CHECK(ptr->uniform_index_wise_bounds());
}

WHEN("We create a state using the default value") {
Expand Down Expand Up @@ -181,6 +183,7 @@ TEST_CASE("BinaryNode") {
CHECK(ptr->lower_bound(i) == 0);
CHECK(ptr->upper_bound(i) == 1);
}
CHECK(ptr->uniform_index_wise_bounds());
}

WHEN("We create a state using the default value") {
Expand Down Expand Up @@ -305,6 +308,7 @@ TEST_CASE("BinaryNode") {
CHECK(bnode_ptr->upper_bound(0) == 1.0);
CHECK(bnode_ptr->upper_bound(1) == 1.0);
CHECK(bnode_ptr->upper_bound(2) == 1.0);
CHECK(!bnode_ptr->uniform_index_wise_bounds());
}

AND_WHEN("We set the state at one of the indices") {
Expand Down Expand Up @@ -406,6 +410,7 @@ TEST_CASE("BinaryNode") {
CHECK(bnode_ptr->lower_bound(1) == 0.0);
CHECK(bnode_ptr->upper_bound(0) == 0.0);
CHECK(bnode_ptr->upper_bound(1) == 1.0);
CHECK(!bnode_ptr->uniform_index_wise_bounds());
}
}

Expand All @@ -423,6 +428,7 @@ TEST_CASE("BinaryNode") {
CHECK(bnode_ptr->lower_bound(1) == 1.0);
CHECK(bnode_ptr->upper_bound(0) == 1.0);
CHECK(bnode_ptr->upper_bound(1) == 1.0);
CHECK(!bnode_ptr->uniform_index_wise_bounds());
}
}

Expand Down Expand Up @@ -457,6 +463,7 @@ TEST_CASE("IntegerNode") {
CHECK(inode_ptr->upper_bound(0) == IntegerNode::default_upper_bound);
CHECK(inode_ptr->lower_bound() == IntegerNode::default_lower_bound);
CHECK(inode_ptr->upper_bound() == IntegerNode::default_upper_bound);
CHECK(inode_ptr->uniform_index_wise_bounds());
}
}

Expand All @@ -477,6 +484,7 @@ TEST_CASE("IntegerNode") {
CHECK(inode.upper_bound(0) == IntegerNode::default_upper_bound);
CHECK(inode.lower_bound() == IntegerNode::default_lower_bound);
CHECK(inode.upper_bound() == IntegerNode::default_upper_bound);
CHECK(inode.uniform_index_wise_bounds());
}
}

Expand All @@ -496,6 +504,7 @@ TEST_CASE("IntegerNode") {
CHECK(inode.upper_bound(0) == 10);
CHECK(inode.lower_bound() == -5);
CHECK(inode.upper_bound() == 10);
CHECK(inode.uniform_index_wise_bounds());
}
}

Expand All @@ -507,6 +516,7 @@ TEST_CASE("IntegerNode") {
CHECK(inode.upper_bound(0) == 10);
CHECK(inode.lower_bound() == IntegerNode::default_lower_bound);
CHECK(inode.upper_bound() == 10);
CHECK(inode.uniform_index_wise_bounds());
}
}

Expand All @@ -518,6 +528,7 @@ TEST_CASE("IntegerNode") {
CHECK(inode1.upper_bound(0) == IntegerNode::default_upper_bound);
CHECK(inode1.lower_bound() == 5);
CHECK(inode1.upper_bound() == IntegerNode::default_upper_bound);
CHECK(inode1.uniform_index_wise_bounds());
}
}

Expand All @@ -529,6 +540,7 @@ TEST_CASE("IntegerNode") {
CHECK(inode1.upper_bound(0) == IntegerNode::default_upper_bound);
CHECK(inode1.lower_bound() == 5);
CHECK(inode1.upper_bound() == IntegerNode::default_upper_bound);
CHECK(inode1.uniform_index_wise_bounds());
}
}

Expand All @@ -551,6 +563,7 @@ TEST_CASE("IntegerNode") {
CHECK(inode_ptr->upper_bound(2) == 7.0);
REQUIRE_THROWS(inode_ptr->lower_bound());
REQUIRE_THROWS(inode_ptr->upper_bound());
CHECK(!inode_ptr->uniform_index_wise_bounds());
}

AND_WHEN("We set the state at one of the indices") {
Expand Down Expand Up @@ -606,6 +619,7 @@ TEST_CASE("IntegerNode") {
CHECK(inode_ptr->upper_bound(1) == 10.0);
CHECK(inode_ptr->lower_bound() == 10.0);
REQUIRE_THROWS(inode_ptr->upper_bound());
CHECK(!inode_ptr->uniform_index_wise_bounds());
}
}

Expand All @@ -631,6 +645,7 @@ TEST_CASE("IntegerNode") {
CHECK(ptr->lower_bound(i) == -10);
CHECK(ptr->upper_bound(i) == IntegerNode::default_upper_bound);
}
CHECK(ptr->uniform_index_wise_bounds());
}

WHEN("We create a state using the default value") {
Expand Down