@@ -137,13 +137,9 @@ void ggml_cann_leaky_relu(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
137
137
ACL_CHECK (aclDestroyTensor (acl_dst));
138
138
}
139
139
140
- void aclnn_concat (ggml_backend_cann_context& ctx, aclTensor *acl_src0 ,
141
- aclTensor *acl_src1, aclTensor * acl_dst, int64_t concat_dim,
140
+ void aclnn_concat (ggml_backend_cann_context& ctx, aclTensorList* tensorList ,
141
+ aclTensor* acl_dst, int64_t concat_dim,
142
142
ggml_tensor* bind_tensor) {
143
-
144
- aclTensor* tensors[] = {acl_src0, acl_src1};
145
- aclTensorList* tensorList = aclCreateTensorList (tensors, 2 );
146
-
147
143
uint64_t workspaceSize = 0 ;
148
144
aclOpExecutor* executor;
149
145
void * workspaceAddr = nullptr ;
@@ -157,12 +153,6 @@ void aclnn_concat(ggml_backend_cann_context& ctx, aclTensor *acl_src0,
157
153
158
154
aclrtStream main_stream = ctx.stream ();
159
155
ACL_CHECK (aclnnCat (workspaceAddr, workspaceSize, executor, main_stream));
160
-
161
- // ACL_CHECK(aclDestroyTensor(acl_src0));
162
- // ACL_CHECK(aclDestroyTensor(acl_src1));
163
- ACL_CHECK (aclDestroyTensorList (tensorList));
164
- ACL_CHECK (aclDestroyTensor (acl_dst));
165
-
166
156
}
167
157
168
158
void ggml_cann_concat (ggml_backend_cann_context& ctx, ggml_tensor* dst) {
@@ -173,14 +163,11 @@ void ggml_cann_concat(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
173
163
aclTensor* acl_dst = create_acl_tensor (dst);
174
164
175
165
int64_t concat_dim = 1 ;
166
+ aclTensor* tensors[] = {acl_src0, acl_src1};
167
+ aclTensorList* tensorList = aclCreateTensorList (tensors, 2 );
168
+ aclnn_concat (ctx, tensorList, acl_dst, concat_dim, dst);
176
169
177
- aclnn_concat (ctx, acl_src0, acl_src1, acl_dst, concat_dim, dst);
178
-
179
- // release acl_src0, acl_src1 in aclnn_concat
180
- // ACL_CHECK(aclDestroyTensor(acl_src0));
181
- // ACL_CHECK(aclDestroyTensor(acl_src1));
182
- // ->
183
- // ACL_CHECK(aclDestroyTensorList(tensorList));
170
+ ACL_CHECK (aclDestroyTensorList (tensorList));
184
171
ACL_CHECK (aclDestroyTensor (acl_dst));
185
172
}
186
173
@@ -1331,9 +1318,12 @@ void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx, ggml_tensor* d
1331
1318
// concat
1332
1319
int64_t concat_dim = 3 ;
1333
1320
aclTensor* acl_dst = create_acl_tensor (dst);
1334
- aclnn_concat (ctx, tmp_cos_tensor, tmp_sin_tensor, acl_dst, concat_dim, dst);
1321
+ aclTensor* tensors[] = {tmp_cos_tensor, tmp_sin_tensor};
1322
+ aclTensorList* tensorList = aclCreateTensorList (tensors, 2 );
1323
+ aclnn_concat (ctx, tensorList, acl_dst, concat_dim, dst);
1335
1324
1336
1325
// release
1326
+ ACL_CHECK (aclDestroyTensorList (tensorList));
1337
1327
ACL_CHECK (aclDestroyTensor (acl_src));
1338
1328
ACL_CHECK (aclDestroyTensor (tmp_arange_tensor));
1339
1329
ACL_CHECK (aclDestroyTensor (tmp_permute_tenosr));
0 commit comments