Skip to content

Commit 662d21e

Browse files
mdvoretc-intelisanghaomryzhov
authored
[GPU] Introduce ConvertWeightCompressedConv1x1ToMatmul pass (#32224)
### Details: - In the target networks, quantized FC is represented as 1x1 conv with compressed weight. - The new transformation converts compressed-weight 1x1 conv into matmul, allowing them to benefit from FCCompressed optimizations ### Description of the issue: An unusual representation of GEMM with compressed int4 weights using a conv1x1 operation was preventing proper identification and kernel selection for the case. Problematic input graph: <img width="421" height="848" alt="tp_conv_tp_before_marked_2" src="https://github.com/user-attachments/assets/5f393442-da8f-4c27-af84-051edf59ad0c" /> The int4 weight value is marked in red while the GEMM pattern is marked in green. The new transformation replaces the pattern marked in green with a MatMul primitive, which is then recognized in successive transformations (ConvertMatMulToFullyConnected, ConvertFullyConnectedToFullyConnectedCompressed) as part of a FullyConnectedCompressed pattern with accompanying weight dequantization. The resulting desirable output is: <img width="611" height="713" alt="tp_conv_tp_after_marked_2" src="https://github.com/user-attachments/assets/18c66374-3e2b-4481-bf31-3705c2128f02" /> The output conversion (marked in blue) is detected as an optional part of the pattern. ### Tickets: - CVS-172090 --------- Co-authored-by: Mingyu Kim <mingyu.kim@intel.com> Co-authored-by: Mikhail Ryzhov <mikhail.ryzhov@intel.com>
1 parent 6373878 commit 662d21e

File tree

20 files changed

+451
-75
lines changed

20 files changed

+451
-75
lines changed

src/common/transformations/include/transformations/utils/utils.hpp

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "openvino/core/rt_info.hpp"
1818
#include "openvino/op/constant.hpp"
1919
#include "openvino/op/convert.hpp"
20+
#include "openvino/op/random_uniform.hpp"
2021
#include "openvino/pass/graph_rewrite.hpp"
2122
#include "openvino/pass/pattern/op/op.hpp"
2223
#include "openvino/util/pp.hpp"
@@ -287,7 +288,43 @@ TRANSFORMATIONS_API bool can_eliminate_eltwise_node(const std::shared_ptr<Node>&
287288

288289
TRANSFORMATIONS_API bool is_constant_and_all_values_equal_int(const Output<Node>& output, const int64_t& v);
289290

290-
TRANSFORMATIONS_API bool is_on_constant_path(const ov::Output<ov::Node>& output);
291+
template <typename... AllowedTypes>
292+
bool is_on_path(const ov::Output<ov::Node>& output) {
293+
auto status = true;
294+
295+
auto root_node = output.get_node();
296+
if (!root_node || root_node->get_output_size() == 0) {
297+
return false;
298+
}
299+
std::deque<ov::Node*> nodes_to_calculate = {root_node};
300+
301+
std::unordered_set<ov::Node*> visited;
302+
while (status && !nodes_to_calculate.empty()) {
303+
auto current_node = nodes_to_calculate.front();
304+
nodes_to_calculate.pop_front();
305+
if (visited.count(current_node)) {
306+
continue;
307+
}
308+
visited.insert(current_node);
309+
// RandomUniform output changes during runtime, so we should not consider it as a constant
310+
if (current_node->get_type_info() == ov::op::v8::RandomUniform::get_type_info_static()) {
311+
return false;
312+
}
313+
314+
if (current_node->get_input_size() == 0 && !(ov::is_type_any_of<AllowedTypes...>(current_node))) {
315+
status = false;
316+
} else {
317+
// not a leaf - continue to search
318+
for (const auto& input_value : current_node->input_values()) {
319+
const auto& input_node = input_value.get_node();
320+
if (!visited.count(input_node)) {
321+
nodes_to_calculate.push_front(input_node);
322+
}
323+
}
324+
}
325+
}
326+
return status;
327+
}
291328

292329
TRANSFORMATIONS_API bool process_subgraph(ov::pass::ModelPass& model_pass, const std::shared_ptr<Node>& node);
293330

src/common/transformations/src/transformations/op_conversions/convert_fc_to_quantized_legacy.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ ov::pass::ConvertFCToFCQuantizedLegacy::ConvertFCToFCQuantizedLegacy() {
4545
const auto& multiply_output_shape = multiply.get_partial_shape();
4646

4747
if (*fc_output_shape.rbegin() != *multiply_output_shape.rbegin() ||
48-
!ov::op::util::is_on_constant_path(weights)) {
48+
!ov::op::util::is_on_path<ov::op::v0::Constant>(weights)) {
4949
return false;
5050
}
5151

src/common/transformations/src/transformations/op_conversions/convert_sequences_to_tensor_iterator.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,17 +128,17 @@ bool convert_sequence_to_ti(const std::shared_ptr<ov::Node>& sequence,
128128
const auto squeezed_x = ov::op::util::make_try_fold<ov::op::v0::Squeeze>(X_body_param, axis_1);
129129
const auto squeezed_w = ov::op::util::make_try_fold<ov::op::v0::Squeeze>(W, axis_0);
130130
std::shared_ptr<ov::op::v0::Parameter> W_body_param;
131-
if (!ov::op::util::is_on_constant_path(squeezed_w))
131+
if (!ov::op::util::is_on_path<ov::op::v0::Constant>(squeezed_w))
132132
W_body_param = std::make_shared<ov::op::v0::Parameter>(squeezed_w->get_element_type(),
133133
squeezed_w->get_output_partial_shape(0));
134134
const auto squeezed_r = ov::op::util::make_try_fold<ov::op::v0::Squeeze>(R, axis_0);
135135
std::shared_ptr<ov::op::v0::Parameter> R_body_param;
136-
if (!ov::op::util::is_on_constant_path(squeezed_r))
136+
if (!ov::op::util::is_on_path<ov::op::v0::Constant>(squeezed_r))
137137
R_body_param = std::make_shared<ov::op::v0::Parameter>(squeezed_r->get_element_type(),
138138
squeezed_r->get_output_partial_shape(0));
139139
const auto squeezed_b = ov::op::util::make_try_fold<ov::op::v0::Squeeze>(B, axis_0);
140140
std::shared_ptr<ov::op::v0::Parameter> B_body_param;
141-
if (!ov::op::util::is_on_constant_path(squeezed_b))
141+
if (!ov::op::util::is_on_path<ov::op::v0::Constant>(squeezed_b))
142142
B_body_param = std::make_shared<ov::op::v0::Parameter>(squeezed_b->get_element_type(),
143143
squeezed_b->get_output_partial_shape(0));
144144

src/common/transformations/src/transformations/utils/utils.cpp

Lines changed: 0 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
#include "openvino/op/multiply.hpp"
2323
#include "openvino/op/paged_attention.hpp"
2424
#include "openvino/op/parameter.hpp"
25-
#include "openvino/op/random_uniform.hpp"
2625
#include "openvino/op/read_value.hpp"
2726
#include "openvino/op/relu.hpp"
2827
#include "openvino/op/reshape.hpp"
@@ -499,43 +498,6 @@ bool is_constant_and_all_values_equal_int(const Output<Node>& output, const int6
499498
return false;
500499
}
501500

502-
bool is_on_constant_path(const ov::Output<ov::Node>& output) {
503-
auto status = true;
504-
505-
auto root_node = output.get_node();
506-
if (!root_node || root_node->get_output_size() == 0) {
507-
return false;
508-
}
509-
std::deque<ov::Node*> nodes_to_calculate = {root_node};
510-
511-
std::unordered_set<ov::Node*> visited;
512-
while (status && !nodes_to_calculate.empty()) {
513-
auto current_node = nodes_to_calculate.front();
514-
nodes_to_calculate.pop_front();
515-
if (visited.count(current_node)) {
516-
continue;
517-
}
518-
visited.insert(current_node);
519-
// RandomUniform output changes during runtime, so we should not consider it as a constant
520-
if (current_node->get_type_info() == ov::op::v8::RandomUniform::get_type_info_static()) {
521-
return false;
522-
}
523-
524-
if (current_node->get_input_size() == 0 && !ov::is_type<ov::op::v0::Constant>(current_node)) {
525-
status = false;
526-
} else {
527-
// not a leaf - continue to search
528-
for (const auto& input_value : current_node->input_values()) {
529-
const auto& input_node = input_value.get_node();
530-
if (!visited.count(input_node)) {
531-
nodes_to_calculate.push_front(input_node);
532-
}
533-
}
534-
}
535-
}
536-
return status;
537-
}
538-
539501
bool process_subgraph(ov::pass::ModelPass& model_pass, const std::shared_ptr<Node>& node) {
540502
bool changed = false;
541503

src/core/tests/pattern.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -469,7 +469,7 @@ TEST(pattern, optional_match_node_with_single_input) {
469469
TestMatcher matcher;
470470
// is_on_const_path
471471
auto param_predicate = [](const Output<Node>& output) {
472-
return !ov::op::util::is_on_constant_path(output);
472+
return !ov::op::util::is_on_path<ov::op::v0::Constant>(output);
473473
};
474474

475475
auto pattern_in_0 = ov::pass::pattern::any_input();

src/inference/src/dev/performance_heuristics.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ MemBandwidthPressure mem_bandwidth_pressure_tolerance(const std::shared_ptr<ov::
6262
output.get_partial_shape().is_static()) {
6363
const auto& shapeInput0 = input0.get_shape();
6464
const auto& shapeInput1 = input1.get_shape();
65-
const auto non_const = !ov::op::util::is_on_constant_path(node->input_value(1));
65+
const auto non_const = !ov::op::util::is_on_path<ov::op::v0::Constant>(node->input_value(1));
6666
const auto& shapeOutput = output.get_shape();
6767
const auto dataSizeInput0 =
6868
std::accumulate(shapeInput0.begin(), shapeInput0.end(), size_t(1), std::multiplies<size_t>());

src/plugins/intel_cpu/src/nodes/fullyconnected.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
#include "openvino/core/node.hpp"
4141
#include "openvino/core/type.hpp"
4242
#include "openvino/core/type/element_type.hpp"
43+
#include "openvino/op/constant.hpp"
4344
#include "openvino/runtime/threading/cpu_message.hpp"
4445
#include "ov_ops/fully_connected.hpp"
4546
#include "ov_ops/fully_connected_compressed.hpp"
@@ -111,15 +112,15 @@ bool FullyConnected::isSupportedOperation(const std::shared_ptr<const ov::Node>&
111112
}
112113

113114
if (ov::is_type<const ov::op::internal::FullyConnected>(op)) {
114-
if (!ov::op::util::is_on_constant_path(op->input_value(BIAS))) {
115+
if (!ov::op::util::is_on_path<ov::op::v0::Constant>(op->input_value(BIAS))) {
115116
errorMessage = "Only Constant operation on 'bias' input is supported";
116117
return false;
117118
}
118119
}
119120

120121
if (ov::is_type<const ov::op::internal::FullyConnectedCompressed>(op)) {
121-
if (!ov::op::util::is_on_constant_path(op->input_value(WEIGHT_SCALES)) ||
122-
!ov::op::util::is_on_constant_path(op->input_value(WEIGHT_ZERO_POINTS))) {
122+
if (!ov::op::util::is_on_path<ov::op::v0::Constant>(op->input_value(WEIGHT_SCALES)) ||
123+
!ov::op::util::is_on_path<ov::op::v0::Constant>(op->input_value(WEIGHT_ZERO_POINTS))) {
123124
errorMessage =
124125
"Only Constant operation on 'weight scales', and 'weight zero points' inputs is supported";
125126
return false;

src/plugins/intel_cpu/src/nodes/gathermatmul.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
#include "openvino/core/parallel.hpp"
4242
#include "openvino/core/type.hpp"
4343
#include "openvino/core/type/element_type.hpp"
44+
#include "openvino/op/constant.hpp"
4445
#include "shape_inference/custom/gathermatmul.hpp"
4546
#include "transformations/cpu_opset/common/op/batch_gather_matmul.hpp"
4647
#include "transformations/cpu_opset/common/op/batch_gather_matmul_compressed.hpp"
@@ -253,22 +254,22 @@ bool GatherMatmul::isSupportedOperation(const std::shared_ptr<const ov::Node>& o
253254
}
254255

255256
// Check that weights input (port 1) is constant
256-
if (!ov::op::util::is_on_constant_path(op->input_value(WEIGHTS))) {
257+
if (!ov::op::util::is_on_path<ov::op::v0::Constant>(op->input_value(WEIGHTS))) {
257258
errorMessage = "Only constant weights are supported for GatherMatmul operation";
258259
return false;
259260
}
260261

261262
// For compressed variant, check that scales and zero points are constant
262263
if (isBatchGatherMatmulCompressed) {
263264
if (op->get_input_size() > WEIGHT_SCALES) {
264-
if (!ov::op::util::is_on_constant_path(op->input_value(WEIGHT_SCALES))) {
265+
if (!ov::op::util::is_on_path<ov::op::v0::Constant>(op->input_value(WEIGHT_SCALES))) {
265266
errorMessage = "Only constant weight scales are supported for GatherMatmul operation";
266267
return false;
267268
}
268269
}
269270

270271
if (op->get_input_size() > WEIGHT_ZERO_POINTS) {
271-
if (!ov::op::util::is_on_constant_path(op->input_value(WEIGHT_ZERO_POINTS))) {
272+
if (!ov::op::util::is_on_path<ov::op::v0::Constant>(op->input_value(WEIGHT_ZERO_POINTS))) {
272273
errorMessage = "Only constant weight zero points are supported for GatherMatmul operation";
273274
return false;
274275
}
@@ -280,7 +281,7 @@ bool GatherMatmul::isSupportedOperation(const std::shared_ptr<const ov::Node>& o
280281
const auto& biasInput = op->input_value(BIAS);
281282
// Skip validation if bias is dynamic (empty constant)
282283
if (biasInput.get_element_type() != ov::element::dynamic) {
283-
if (!ov::op::util::is_on_constant_path(biasInput)) {
284+
if (!ov::op::util::is_on_path<ov::op::v0::Constant>(biasInput)) {
284285
errorMessage = "Only constant bias is supported for GatherMatmul operation";
285286
return false;
286287
}

src/plugins/intel_cpu/src/nodes/rnn.cpp

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
#include "openvino/core/type.hpp"
4444
#include "openvino/core/type/element_type.hpp"
4545
#include "openvino/core/type/element_type_traits.hpp"
46+
#include "openvino/op/constant.hpp"
4647
#include "openvino/op/gru_cell.hpp"
4748
#include "openvino/op/gru_sequence.hpp"
4849
#include "openvino/op/lstm_cell.hpp"
@@ -274,9 +275,9 @@ bool RNN::isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::s
274275
ov::op::v0::RNNCell::get_type_info_static(),
275276
ov::op::v3::GRUCell::get_type_info_static())) {
276277
// Plug-in does not support dynamism on weights.
277-
if (!ov::op::util::is_on_constant_path(op->input_value(2)) ||
278-
!ov::op::util::is_on_constant_path(op->input_value(3)) ||
279-
(op->get_input_size() > 4 && !ov::op::util::is_on_constant_path(op->input_value(4)))) {
278+
if (!ov::op::util::is_on_path<ov::op::v0::Constant>(op->input_value(2)) ||
279+
!ov::op::util::is_on_path<ov::op::v0::Constant>(op->input_value(3)) ||
280+
(op->get_input_size() > 4 && !ov::op::util::is_on_path<ov::op::v0::Constant>(op->input_value(4)))) {
280281
errorMessage = "Node expects constants as W, R, B inputs.";
281282
return false;
282283
}
@@ -286,9 +287,9 @@ bool RNN::isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::s
286287
ov::op::v5::GRUSequence::get_type_info_static(),
287288
ov::op::v5::RNNSequence::get_type_info_static())) {
288289
// Plug-in does not support dynamism on weights.
289-
if (!ov::op::util::is_on_constant_path(op->input_value(3)) ||
290-
!ov::op::util::is_on_constant_path(op->input_value(4)) ||
291-
(op->get_input_size() > 5 && !ov::op::util::is_on_constant_path(op->input_value(5)))) {
290+
if (!ov::op::util::is_on_path<ov::op::v0::Constant>(op->input_value(3)) ||
291+
!ov::op::util::is_on_path<ov::op::v0::Constant>(op->input_value(4)) ||
292+
(op->get_input_size() > 5 && !ov::op::util::is_on_path<ov::op::v0::Constant>(op->input_value(5)))) {
292293
errorMessage = "Node expects constants as W, R, B inputs.";
293294
return false;
294295
}
@@ -302,9 +303,9 @@ bool RNN::isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::s
302303
return false;
303304
}
304305
// Plug-in does not support dynamism on weights.
305-
if (!ov::op::util::is_on_constant_path(op->input_value(4)) ||
306-
!ov::op::util::is_on_constant_path(op->input_value(5)) ||
307-
!ov::op::util::is_on_constant_path(op->input_value(6))) {
306+
if (!ov::op::util::is_on_path<ov::op::v0::Constant>(op->input_value(4)) ||
307+
!ov::op::util::is_on_path<ov::op::v0::Constant>(op->input_value(5)) ||
308+
!ov::op::util::is_on_path<ov::op::v0::Constant>(op->input_value(6))) {
308309
errorMessage = "Node expects static shaped W, R, B inputs.";
309310
return false;
310311
}

src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/batch_gather_matmul.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,9 @@ void BatchGatherMatmul::validate_and_infer_types() {
5454
", expected at least 4.");
5555

5656
// Check input B is on constant path
57-
NODE_VALIDATION_CHECK(this, ov::op::util::is_on_constant_path(input_value(1)), "Input B must be on constant path.");
57+
NODE_VALIDATION_CHECK(this,
58+
ov::op::util::is_on_path<ov::op::v0::Constant>(input_value(1)),
59+
"Input B must be on constant path.");
5860

5961
const auto& a_shape = get_input_partial_shape(0);
6062
const auto& b_shape = get_input_partial_shape(1);

0 commit comments

Comments
 (0)