Skip to content

Commit 715718f

Browse files
committed
hexagon: update buffer support checks to use tensor structure
1 parent 1085c98 commit 715718f

File tree

1 file changed

+14
-13
lines changed

1 file changed

+14
-13
lines changed

ggml/src/ggml-hexagon/ggml-hexagon.cpp

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1916,15 +1916,16 @@ static inline bool hex_supported_buffer(const struct ggml_hexagon_session * sess
19161916
return true;
19171917
}
19181918

1919-
template <typename... _TBuffers>
1919+
template <typename... _TTensor>
19201920
static inline bool hex_supported_buffer(const struct ggml_hexagon_session * sess,
1921-
ggml_backend_buffer_t buffer,
1922-
_TBuffers... buffers) {
1923-
if (buffer && (!ggml_backend_buffer_is_hexagon(buffer) || ggml_backend_hexagon_buffer_get_sess(buffer) != sess)) {
1921+
const ggml_tensor * t,
1922+
_TTensor... tensors) {
1923+
if (t && t->buffer &&
1924+
(!ggml_backend_buffer_is_hexagon(t->buffer) || ggml_backend_hexagon_buffer_get_sess(t->buffer) != sess)) {
19241925
return false;
19251926
}
19261927

1927-
return hex_supported_buffer(sess, buffers...);
1928+
return hex_supported_buffer(sess, tensors...);
19281929
}
19291930

19301931
static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * sess, const struct ggml_tensor * dst) {
@@ -1974,7 +1975,7 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s
19741975
}
19751976

19761977
// src0 & src1 & dst must be mapped to the same session
1977-
if (!hex_supported_buffer(sess, src0->buffer, src1->buffer, dst->buffer)) {
1978+
if (!hex_supported_buffer(sess, src0, src1, dst)) {
19781979
return false;
19791980
}
19801981

@@ -2022,7 +2023,7 @@ static bool ggml_hexagon_supported_mul_mat_id(const struct ggml_hexagon_session
20222023

20232024
// src0 (weights) must be repacked and mapped to the same session
20242025
// src1 & sr2 & dst must be mapped to the same session
2025-
if (!hex_supported_buffer(sess, src0->buffer, src1->buffer, src2->buffer, dst->buffer)) {
2026+
if (!hex_supported_buffer(sess, src0, src1, src2, dst)) {
20262027
return false;
20272028
}
20282029

@@ -2056,7 +2057,7 @@ static bool ggml_hexagon_supported_binary(const struct ggml_hexagon_session * se
20562057
}
20572058

20582059
// src0, src1 & dst must be mapped to the same session
2059-
if (!hex_supported_buffer(sess, src0->buffer, src1->buffer, dst->buffer)) {
2060+
if (!hex_supported_buffer(sess, src0, src1, dst)) {
20602061
return false;
20612062
}
20622063

@@ -2088,7 +2089,7 @@ static bool ggml_hexagon_supported_add_id(const struct ggml_hexagon_session * se
20882089
}
20892090

20902091
// src0, src1 & dst must be mapped to the same session
2091-
if (!hex_supported_buffer(sess, src0->buffer, src1->buffer, src2->buffer, dst->buffer)) {
2092+
if (!hex_supported_buffer(sess, src0, src1, src2, dst)) {
20922093
return false;
20932094
}
20942095

@@ -2115,7 +2116,7 @@ static bool ggml_hexagon_supported_unary(const struct ggml_hexagon_session * ses
21152116
}
21162117

21172118
// src0 & dst must be mapped to the same session
2118-
if (!hex_supported_buffer(sess, src0->buffer, dst->buffer)) {
2119+
if (!hex_supported_buffer(sess, src0, dst)) {
21192120
return false;
21202121
}
21212122

@@ -2152,7 +2153,7 @@ static bool ggml_hexagon_supported_activations(const struct ggml_hexagon_session
21522153
}
21532154

21542155
// src0, src1 & dst must be mapped to the same session
2155-
if (!hex_supported_buffer(sess, src0->buffer, src1->buffer, dst->buffer)) {
2156+
if (!hex_supported_buffer(sess, src0, src1, dst)) {
21562157
return false;
21572158
}
21582159

@@ -2205,7 +2206,7 @@ static bool ggml_hexagon_supported_softmax(const struct ggml_hexagon_session * s
22052206
}
22062207

22072208
// src0, src1 & dst must be mapped to the same session
2208-
if (!hex_supported_buffer(sess, src0->buffer, src1->buffer, dst->buffer)) {
2209+
if (!hex_supported_buffer(sess, src0, src1, dst)) {
22092210
return false;
22102211
}
22112212

@@ -2260,7 +2261,7 @@ static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess
22602261
}
22612262

22622263
// src0, src1, src2 & dst must be mapped to the same session
2263-
if (!hex_supported_buffer(sess, src0->buffer, src1->buffer, src2->buffer, dst->buffer)) {
2264+
if (!hex_supported_buffer(sess, src0, src1, src2, dst)) {
22642265
return false;
22652266
}
22662267

0 commit comments

Comments
 (0)