1/* 2 * Copyright (C) 2015 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 18#include "RenderScript.h" 19#include "rsCppInternal.h" 20 21using namespace android; 22using namespace RSC; 23 24// ScriptIntrinsicBLAS APIS 25ScriptIntrinsicBLAS::ScriptIntrinsicBLAS(sp<RS> rs, sp<const Element> e) 26 : ScriptIntrinsic(rs, RS_SCRIPT_INTRINSIC_ID_BLAS, e) { 27 28} 29 30sp<ScriptIntrinsicBLAS> ScriptIntrinsicBLAS::create(sp<RS> rs) { 31 return new ScriptIntrinsicBLAS(rs, Element::U32(rs)); 32} 33 34enum RsBlasDataType { 35 SINGLE, 36 DOUBLE, 37 SINGLE_COMPLEX, 38 DOUBLE_COMPLEX 39}; 40 41static RsBlasCall 42setUpBLASCall(RsBlasDataType dataType, RsBlasFunction func, 43 int TransA, int TransB, int Side, int Uplo, int Diag, 44 int M, int N, int K, int incX, int incY, int KL, int KU, 45 float alphaF, float betaF, double alphaD, double betaD, 46 float alphaCX, float alphaCY, float betaCX, float betaCY, 47 double alphaZX, double alphaZY, double betaZX, double betaZY 48 ) { 49 RsBlasCall call; 50 memset(&call, 0, sizeof(call)); 51 call.func = func; 52 call.transA = (RsBlasTranspose)TransA; 53 call.transB = (RsBlasTranspose)TransB; 54 call.side = (RsBlasSide)Side; 55 call.uplo = (RsBlasUplo)Uplo; 56 call.diag = (RsBlasDiag)Diag; 57 call.M = M; 58 call.N = N; 59 call.K = K; 60 61 switch (dataType) { 62 case SINGLE: 63 // For Single-precision BLAS. 64 call.alpha.f = alphaF; 65 call.beta.f = betaF; 66 break; 67 case DOUBLE: 68 // For Double-precision BLAS. 69 call.alpha.d = alphaD; 70 call.beta.d = betaD; 71 break; 72 case SINGLE_COMPLEX: 73 // For Single-precision complex BLAS. 74 call.alpha.c.r = alphaCX; 75 call.alpha.c.i = alphaCY; 76 call.beta.c.r = betaCX; 77 call.beta.c.i = betaCY; 78 break; 79 case DOUBLE_COMPLEX: 80 // For Double-precision complex BLAS. 81 call.alpha.z.r = alphaZX; 82 call.alpha.z.i = alphaZY; 83 call.beta.z.r = betaZX; 84 call.beta.z.i = betaZY; 85 break; 86 default: 87 break; 88 } 89 90 call.incX = incX; 91 call.incY = incY; 92 call.KL = KL; 93 call.KU = KU; 94 95 return call; 96} 97 98static void 99nScriptIntrinsicBLAS_Single(RS* mRS, RsContext con, RsScript id, RsBlasFunction func, int TransA, 100 int TransB, int Side, int Uplo, int Diag, int M, int N, int K, 101 float alpha, RsAllocation A, RsAllocation B, 102 float beta, RsAllocation C, int incX, int incY, int KL, int KU) { 103 RsBlasCall call = setUpBLASCall(SINGLE, func, TransA, TransB, Side, Uplo, Diag, 104 M, N, K, incX, incY, KL, KU, alpha, beta, 0.0, 0.0, 105 0.0f, 0.0f, 0.0f, 0.0f, 0.0, 0.0, 0.0, 0.0); 106 RsAllocation in_allocs[3] = {A, B, C}; 107 tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, sizeof(in_allocs), nullptr, 108 &call, sizeof(call), nullptr, 0)); 109} 110 111 112static void 113nScriptIntrinsicBLAS_Double(RS* mRS, RsContext con, RsScript id, RsBlasFunction func, int TransA, 114 int TransB, int Side, int Uplo, int Diag, int M, int N, int K, 115 double alpha, RsAllocation A, RsAllocation B, 116 double beta, RsAllocation C, int incX, int incY, int KL, int KU) { 117 RsBlasCall call = setUpBLASCall(DOUBLE, func, TransA, TransB, Side, Uplo, Diag, 118 M, N, K, incX, incY, KL, KU, 0.0f, 0.0f, alpha, beta, 119 0.0f, 0.0f, 0.0f, 0.0f, 0.0, 0.0, 0.0, 0.0); 120 RsAllocation in_allocs[3] = {A, B, C}; 121 tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, sizeof(in_allocs), nullptr, 122 &call, sizeof(call), nullptr, 0)); 123} 124 125static void 126nScriptIntrinsicBLAS_Complex(RS* mRS, RsContext con, RsScript id, RsBlasFunction func, int TransA, 127 int TransB, int Side, int Uplo, int Diag, int M, int N, int K, 128 float alphaX, float alphaY, RsAllocation A, RsAllocation B, 129 float betaX, float betaY, RsAllocation C, int incX, int incY, int KL, int KU) { 130 RsBlasCall call = setUpBLASCall(SINGLE_COMPLEX, func, TransA, TransB, Side, Uplo, Diag, 131 M, N, K, incX, incY, KL, KU, 0.0f, 0.0f, 0.0, 0.0, 132 alphaX, alphaY, betaX, betaY, 0.0, 0.0, 0.0, 0.0); 133 RsAllocation in_allocs[3] = {A, B, C}; 134 tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, sizeof(in_allocs), nullptr, 135 &call, sizeof(call), nullptr, 0)); 136} 137 138static void 139nScriptIntrinsicBLAS_Z(RS* mRS, RsContext con, RsScript id, RsBlasFunction func, int TransA, 140 int TransB, int Side, int Uplo, int Diag, int M, int N, int K, 141 double alphaX, double alphaY, RsAllocation A, RsAllocation B, 142 double betaX, double betaY, RsAllocation C, int incX, int incY, int KL, int KU) { 143 RsBlasCall call = setUpBLASCall(DOUBLE_COMPLEX, func, TransA, TransB, Side, Uplo, Diag, 144 M, N, K, incX, incY, KL, KU, 0.0f, 0.0f, 0.0, 0.0, 145 0.0f, 0.0f, 0.0f, 0.0f, alphaX, alphaY, betaX, betaY); 146 RsAllocation in_allocs[3] = {A, B, C}; 147 tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, sizeof(in_allocs), nullptr, 148 &call, sizeof(call), nullptr, 0)); 149} 150 151 152static void 153nScriptIntrinsicBLAS_BNNM(RS* mRS, RsContext con, RsScript id, int M, int N, int K, 154 RsAllocation A, int a_offset, RsAllocation B, int b_offset, 155 RsAllocation C, int c_offset, int c_mult_int) { 156 RsBlasCall call; 157 memset(&call, 0, sizeof(call)); 158 call.func = RsBlas_bnnm; 159 call.M = M; 160 call.N = N; 161 call.K = K; 162 call.a_offset = a_offset & 0xFF; 163 call.b_offset = b_offset & 0xFF; 164 call.c_offset = c_offset; 165 call.c_mult_int = c_mult_int; 166 167 RsAllocation in_allocs[3] = {A, B, C}; 168 tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, sizeof(in_allocs), nullptr, 169 &call, sizeof(call), nullptr, 0)); 170} 171 172/** 173 * Level 2 BLAS 174 */ 175static void validateGEMV(RS* mRS, sp<const Element> e, RsBlasTranspose TransA, sp<Allocation> A, 176 sp<Allocation> X, int incX, sp<Allocation> Y, int incY) { 177 int M = A->getType()->getY(); 178 int N = A->getType()->getX(); 179 if (!A->getType()->getElement()->isCompatible(e) || 180 !X->getType()->getElement()->isCompatible(e) || 181 !Y->getType()->getElement()->isCompatible(e)) { 182 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); 183 } 184 if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) { 185 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1"); 186 } 187 188 if (incX <= 0 || incY <= 0) { 189 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0"); 190 } 191 int expectedXDim = -1, expectedYDim = -1; 192 if (TransA == RsBlasNoTrans) { 193 expectedXDim = 1 + (N - 1) * incX; 194 expectedYDim = 1 + (M - 1) * incY; 195 } else { 196 expectedXDim = 1 + (M - 1) * incX; 197 expectedYDim = 1 + (N - 1) * incY; 198 } 199 if ((int)X->getType()->getX() != expectedXDim || 200 (int)Y->getType()->getX() != expectedYDim) { 201 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GEMV"); 202 } 203} 204 205void ScriptIntrinsicBLAS::SGEMV(RsBlasTranspose TransA, float alpha, sp<Allocation> A, sp<Allocation> X, 206 int incX, float beta, sp<Allocation> Y, int incY) { 207 validateGEMV(mRS, Element::F32(mRS), TransA, A, X, incX, Y, incY); 208 int M = A->getType()->getY(); 209 int N = A->getType()->getX(); 210 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sgemv, 211 TransA, 0, 0, 0, 0, M, N, 0, 212 alpha, A->getID(), X->getID(), 213 beta, Y->getID(), incX, incY, 0, 0); 214} 215 216void ScriptIntrinsicBLAS::DGEMV(RsBlasTranspose TransA, double alpha, sp<Allocation> A, sp<Allocation> X, 217 int incX, double beta, sp<Allocation> Y, int incY) { 218 validateGEMV(mRS, Element::F64(mRS), TransA, A, X, incX, Y, incY); 219 int M = A->getType()->getY(); 220 int N = A->getType()->getX(); 221 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dgemv, 222 TransA, 0, 0, 0, 0, M, N, 0, 223 alpha, A->getID(), X->getID(), 224 beta, Y->getID(), incX, incY, 0, 0); 225} 226 227void ScriptIntrinsicBLAS::CGEMV(RsBlasTranspose TransA, Float2 alpha, sp<Allocation> A, sp<Allocation> X, 228 int incX, Float2 beta, sp<Allocation> Y, int incY) { 229 validateGEMV(mRS, Element::F32_2(mRS), TransA, A, X, incX, Y, incY); 230 int M = A->getType()->getY(); 231 int N = A->getType()->getX(); 232 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgemv, 233 TransA, 0, 0, 0, 0, M, N, 0, 234 alpha.x, alpha.y, A->getID(), X->getID(), 235 beta.x, beta.y, Y->getID(), incX, incY, 0, 0); 236} 237 238void ScriptIntrinsicBLAS::ZGEMV(RsBlasTranspose TransA, Double2 alpha, sp<Allocation> A, sp<Allocation> X, 239 int incX, Double2 beta, sp<Allocation> Y, int incY) { 240 validateGEMV(mRS, Element::F64_2(mRS), TransA, A, X, incX, Y, incY); 241 int M = A->getType()->getY(); 242 int N = A->getType()->getX(); 243 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgemv, 244 TransA, 0, 0, 0, 0, M, N, 0, 245 alpha.x, alpha.y, A->getID(), X->getID(), 246 beta.x, beta.y, Y->getID(), incX, incY, 0, 0); 247} 248 249void ScriptIntrinsicBLAS::SGBMV(RsBlasTranspose TransA, int KL, int KU, float alpha, sp<Allocation> A, 250 sp<Allocation> X, int incX, float beta, sp<Allocation> Y, int incY) { 251 // GBMV has the same validation requirements as GEMV + KL and KU >= 0 252 validateGEMV(mRS, Element::F32(mRS), TransA, A, X, incX, Y, incY); 253 if (KL < 0 || KU < 0) { 254 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "KL and KU must be greater than or equal to 0"); 255 } 256 int M = A->getType()->getY(); 257 int N = A->getType()->getX(); 258 259 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sgbmv, 260 TransA, 0, 0, 0, 0, M, N, 0, 261 alpha, A->getID(), X->getID(), 262 beta, Y->getID(), incX, incY, KL, KU); 263} 264 265void ScriptIntrinsicBLAS::DGBMV(RsBlasTranspose TransA, int KL, int KU, double alpha, sp<Allocation> A, 266 sp<Allocation> X, int incX, double beta, sp<Allocation> Y, int incY) { 267 // GBMV has the same validation requirements as GEMV + KL and KU >= 0 268 validateGEMV(mRS, Element::F64(mRS), TransA, A, X, incX, Y, incY); 269 if (KL < 0 || KU < 0) { 270 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "KL and KU must be greater than or equal to 0"); 271 } 272 int M = A->getType()->getY(); 273 int N = A->getType()->getX(); 274 275 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dgbmv, 276 TransA, 0, 0, 0, 0, M, N, 0, 277 alpha, A->getID(), X->getID(), 278 beta, Y->getID(), incX, incY, KL, KU); 279} 280 281void ScriptIntrinsicBLAS::CGBMV(RsBlasTranspose TransA, int KL, int KU, Float2 alpha, sp<Allocation> A, 282 sp<Allocation> X, int incX, Float2 beta, sp<Allocation> Y, int incY) { 283 // GBMV has the same validation requirements as GEMV + KL and KU >= 0 284 validateGEMV(mRS, Element::F32_2(mRS), TransA, A, X, incX, Y, incY); 285 if (KL < 0 || KU < 0) { 286 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "KL and KU must be greater than or equal to 0"); 287 } 288 int M = A->getType()->getY(); 289 int N = A->getType()->getX(); 290 291 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgbmv, 292 TransA, 0, 0, 0, 0, M, N, 0, 293 alpha.x, alpha.y, A->getID(), X->getID(), 294 beta.x, beta.y, Y->getID(), incX, incY, KL, KU); 295} 296 297void ScriptIntrinsicBLAS::ZGBMV(RsBlasTranspose TransA, int KL, int KU, Double2 alpha, sp<Allocation> A, 298 sp<Allocation> X, int incX, Double2 beta, sp<Allocation> Y, int incY) { 299 // GBMV has the same validation requirements as GEMV + KL and KU >= 0 300 validateGEMV(mRS, Element::F64_2(mRS), TransA, A, X, incX, Y, incY); 301 if (KL < 0 || KU < 0) { 302 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "KL and KU must be greater than or equal to 0"); 303 } 304 int M = A->getType()->getY(); 305 int N = A->getType()->getX(); 306 307 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgbmv, 308 TransA, 0, 0, 0, 0, M, N, 0, 309 alpha.x, alpha.y, A->getID(), X->getID(), 310 beta.x, beta.y, Y->getID(), incX, incY, KL, KU); 311} 312 313static void validateTRMV(RS* mRS, sp<const Element> e, RsBlasUplo Uplo, RsBlasTranspose TransA, 314 RsBlasDiag Diag, sp<Allocation> A, sp<Allocation> X, int incX) { 315 int N = A->getType()->getY(); 316 if ((int)A->getType()->getX() != N) { 317 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "A must be a square matrix for TRMV"); 318 } 319 if (!A->getType()->getElement()->isCompatible(e) || 320 !X->getType()->getElement()->isCompatible(e)) { 321 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); 322 } 323 if (X->getType()->getY() > 1) { 324 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1"); 325 } 326 327 if (incX <= 0) { 328 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0"); 329 } 330 int expectedXDim = 1 + (N - 1) * incX; 331 if ((int)X->getType()->getX() != expectedXDim) { 332 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for TRMV"); 333 } 334} 335 336static int validateTPMV(RS* mRS, sp<const Element> e, RsBlasUplo Uplo, RsBlasTranspose TransA, 337 RsBlasDiag Diag, sp<Allocation> Ap, sp<Allocation> X, int incX) { 338 if (!Ap->getType()->getElement()->isCompatible(e) || 339 !X->getType()->getElement()->isCompatible(e)) { 340 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); 341 } 342 if (X->getType()->getY() > 1) { 343 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1"); 344 } 345 346 if (Ap->getType()->getY() > 1) { 347 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Ap must have a Y dimension of 0 or 1"); 348 } 349 350 int N = sqrt((double)Ap->getType()->getX() * 2); 351 if ((int)Ap->getType()->getX() != ((N * (N+1)) / 2)) { 352 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid dimension for Ap"); 353 } 354 if (incX <= 0) { 355 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0"); 356 } 357 int expectedXDim = 1 + (N - 1) * incX; 358 if ((int)X->getType()->getX() != expectedXDim) { 359 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for TPMV"); 360 } 361 362 return N; 363} 364 365 366void ScriptIntrinsicBLAS::STRMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 367 sp<Allocation> A, sp<Allocation> X, int incX) { 368 validateTRMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, A, X, incX); 369 int N = A->getType()->getY(); 370 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_strmv, 371 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 372 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0); 373} 374 375void ScriptIntrinsicBLAS::DTRMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 376 sp<Allocation> A, sp<Allocation> X, int incX) { 377 validateTRMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, A, X, incX); 378 int N = A->getType()->getY(); 379 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtrmv, 380 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 381 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0); 382} 383 384void ScriptIntrinsicBLAS::CTRMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 385 sp<Allocation> A, sp<Allocation> X, int incX) { 386 validateTRMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, A, X, incX); 387 int N = A->getType()->getY(); 388 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctrmv, 389 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, 390 A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); 391} 392 393void ScriptIntrinsicBLAS::ZTRMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 394 sp<Allocation> A, sp<Allocation> X, int incX) { 395 validateTRMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, A, X, incX); 396 int N = A->getType()->getY(); 397 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztrmv, 398 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, 399 A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); 400} 401 402void ScriptIntrinsicBLAS::STBMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 403 int K, sp<Allocation> A, sp<Allocation> X, int incX) { 404 // TBMV has the same requirements as TRMV + K >= 0 405 if (K < 0) { 406 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0"); 407 } 408 validateTRMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, A, X, incX); 409 int N = A->getType()->getY(); 410 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_stbmv, 411 TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 412 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0); 413} 414 415void ScriptIntrinsicBLAS::DTBMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 416 int K, sp<Allocation> A, sp<Allocation> X, int incX) { 417 // TBMV has the same requirements as TRMV + K >= 0 418 if (K < 0) { 419 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0"); 420 } 421 validateTRMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, A, X, incX); 422 int N = A->getType()->getY(); 423 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtbmv, 424 TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 425 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0); 426} 427 428void ScriptIntrinsicBLAS::CTBMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 429 int K, sp<Allocation> A, sp<Allocation> X, int incX) { 430 // TBMV has the same requirements as TRMV + K >= 0 431 if (K < 0) { 432 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0"); 433 } 434 validateTRMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, A, X, incX); 435 int N = A->getType()->getY(); 436 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctbmv, 437 TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, 438 A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); 439} 440 441void ScriptIntrinsicBLAS::ZTBMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 442 int K, sp<Allocation> A, sp<Allocation> X, int incX) { 443 // TBMV has the same requirements as TRMV + K >= 0 444 if (K < 0) { 445 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0"); 446 } 447 validateTRMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, A, X, incX); 448 int N = A->getType()->getY(); 449 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztbmv, 450 TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, 451 A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); 452} 453 454void ScriptIntrinsicBLAS::STPMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 455 sp<Allocation> Ap, sp<Allocation> X, int incX) { 456 int N = validateTPMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, Ap, X, incX); 457 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_stpmv, 458 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 459 Ap->getID(), X->getID(), 0, 0, incX, 0, 0, 0); 460} 461 462void ScriptIntrinsicBLAS::DTPMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 463 sp<Allocation> Ap, sp<Allocation> X, int incX) { 464 int N = validateTPMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, Ap, X, incX); 465 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtpmv, 466 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 467 Ap->getID(), X->getID(), 0, 0, incX, 0, 0, 0); 468} 469 470void ScriptIntrinsicBLAS::CTPMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 471 sp<Allocation> Ap, sp<Allocation> X, int incX) { 472 int N = validateTPMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, Ap, X, incX); 473 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctpmv, 474 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, 475 Ap->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); 476} 477 478void ScriptIntrinsicBLAS::ZTPMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 479 sp<Allocation> Ap, sp<Allocation> X, int incX) { 480 int N = validateTPMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, Ap, X, incX); 481 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztpmv, 482 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, 483 Ap->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); 484} 485 486void ScriptIntrinsicBLAS::STRSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 487 sp<Allocation> A, sp<Allocation> X, int incX) { 488 // TRSV is the same as TRMV 489 validateTRMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, A, X, incX); 490 int N = A->getType()->getY(); 491 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_strsv, 492 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 493 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0); 494} 495 496void ScriptIntrinsicBLAS::DTRSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 497 sp<Allocation> A, sp<Allocation> X, int incX) { 498 // TRSV is the same as TRMV 499 validateTRMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, A, X, incX); 500 int N = A->getType()->getY(); 501 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtrsv, 502 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 503 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0); 504 505} 506 507void ScriptIntrinsicBLAS::CTRSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 508 sp<Allocation> A, sp<Allocation> X, int incX) { 509 // TRSV is the same as TRMV 510 validateTRMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, A, X, incX); 511 int N = A->getType()->getY(); 512 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctrsv, 513 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, 514 A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); 515 516} 517 518void ScriptIntrinsicBLAS::ZTRSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 519 sp<Allocation> A, sp<Allocation> X, int incX) { 520 // TRSV is the same as TRMV 521 validateTRMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, A, X, incX); 522 int N = A->getType()->getY(); 523 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztrsv, 524 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, 525 A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); 526 527} 528 529void ScriptIntrinsicBLAS::STBSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 530 int K, sp<Allocation> A, sp<Allocation> X, int incX) { 531 // TBSV is the same as TRMV + K >= 0 532 validateTRMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, A, X, incX); 533 int N = A->getType()->getY(); 534 if (K < 0) { 535 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Number of diagonals must be positive"); 536 } 537 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_stbsv, 538 TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 539 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0); 540} 541 542void ScriptIntrinsicBLAS::DTBSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 543 int K, sp<Allocation> A, sp<Allocation> X, int incX) { 544 // TBSV is the same as TRMV + K >= 0 545 validateTRMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, A, X, incX); 546 int N = A->getType()->getY(); 547 if (K < 0) { 548 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Number of diagonals must be positive"); 549 } 550 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtbsv, 551 TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 552 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0); 553} 554 555void ScriptIntrinsicBLAS::CTBSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 556 int K, sp<Allocation> A, sp<Allocation> X, int incX) { 557 // TBSV is the same as TRMV + K >= 0 558 validateTRMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, A, X, incX); 559 int N = A->getType()->getY(); 560 if (K < 0) { 561 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Number of diagonals must be positive"); 562 } 563 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctbsv, 564 TransA, 0, 0, Uplo, Diag, 0, N, K, 565 0, 0, A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); 566} 567 568void ScriptIntrinsicBLAS::ZTBSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 569 int K, sp<Allocation> A, sp<Allocation> X, int incX) { 570 // TBSV is the same as TRMV + K >= 0 571 validateTRMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, A, X, incX); 572 int N = A->getType()->getY(); 573 if (K < 0) { 574 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Number of diagonals must be positive"); 575 } 576 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztbsv, 577 TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, 578 A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); 579} 580 581void ScriptIntrinsicBLAS::STPSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 582 sp<Allocation> Ap, sp<Allocation> X, int incX) { 583 // TPSV is same as TPMV 584 int N = validateTPMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, Ap, X, incX); 585 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_stpsv, 586 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 587 Ap->getID(), X->getID(), 0, 0, incX, 0, 0, 0); 588} 589 590void ScriptIntrinsicBLAS::DTPSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 591 sp<Allocation> Ap, sp<Allocation> X, int incX) { 592 // TPSV is same as TPMV 593 int N = validateTPMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, Ap, X, incX); 594 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtpsv, 595 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 596 Ap->getID(), X->getID(), 0, 0, incX, 0, 0, 0); 597} 598 599void ScriptIntrinsicBLAS::CTPSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 600 sp<Allocation> Ap, sp<Allocation> X, int incX) { 601 // TPSV is same as TPMV 602 int N = validateTPMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, Ap, X, incX); 603 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctpsv, 604 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, 605 Ap->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); 606} 607 608void ScriptIntrinsicBLAS::ZTPSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 609 sp<Allocation> Ap, sp<Allocation> X, int incX) { 610 // TPSV is same as TPMV 611 int N = validateTPMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, Ap, X, incX); 612 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztpsv, 613 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, 614 Ap->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); 615} 616 617/** 618 * Level 2, S and D only 619 */ 620static int validateSYMV(RS* mRS, sp<const Element> e, RsBlasUplo Uplo, sp<Allocation> A, 621 sp<Allocation> X, sp<Allocation> Y, int incX, int incY) { 622 int N = A->getType()->getY(); 623 if ((int)A->getType()->getX() != N) { 624 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "A must be a square matrix for SYMV"); 625 } 626 if (!A->getType()->getElement()->isCompatible(e) || 627 !X->getType()->getElement()->isCompatible(e) || 628 !Y->getType()->getElement()->isCompatible(e) ) { 629 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); 630 } 631 if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) { 632 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1"); 633 } 634 635 if (incX <= 0 || incY <= 0) { 636 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0"); 637 } 638 int expectedXDim = 1 + (N - 1) * incX; 639 if ((int)X->getType()->getX() != expectedXDim) { 640 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SYMV"); 641 } 642 int expectedYDim = 1 + (N - 1) * incY; 643 if ((int)Y->getType()->getX() != expectedYDim) { 644 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SYMV"); 645 } 646 return N; 647} 648static int validateSPMV(RS* mRS, sp<const Element> e, RsBlasUplo Uplo, sp<Allocation> Ap, 649 sp<Allocation> X, int incX, sp<Allocation> Y, int incY) { 650 if (!Ap->getType()->getElement()->isCompatible(e) || 651 !X->getType()->getElement()->isCompatible(e) || 652 !Y->getType()->getElement()->isCompatible(e)) { 653 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); 654 } 655 if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) { 656 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1"); 657 } 658 659 if (Ap->getType()->getY() > 1) { 660 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Ap must have a Y dimension of 0 or 1"); 661 } 662 663 int N = sqrt((double)Ap->getType()->getX() * 2); 664 if ((int)Ap->getType()->getX() != ((N * (N+1)) / 2)) { 665 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid dimension for Ap"); 666 } 667 if (incX <= 0 || incY <= 0) { 668 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0"); 669 } 670 int expectedXDim = 1 + (N - 1) * incX; 671 if ((int)X->getType()->getX() != expectedXDim) { 672 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SPMV"); 673 } 674 int expectedYDim = 1 + (N - 1) * incY; 675 if ((int)Y->getType()->getX() != expectedYDim) { 676 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SPMV"); 677 } 678 679 return N; 680} 681static void validateGER(RS* mRS, sp<const Element> e, sp<Allocation> X, int incX, 682 sp<Allocation> Y, int incY, sp<Allocation> A) { 683 if (!A->getType()->getElement()->isCompatible(e) || 684 !X->getType()->getElement()->isCompatible(e) || 685 !Y->getType()->getElement()->isCompatible(e) ) { 686 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); 687 } 688 689 if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) { 690 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1"); 691 } 692 693 int M = A->getType()->getY(); 694 int N = A->getType()->getX(); 695 696 if (N < 1 || M < 1) { 697 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "M and N must be 1 or greater for GER"); 698 } 699 if (incX <= 0 || incY <= 0) { 700 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0"); 701 } 702 int expectedXDim = 1 + (M - 1) * incX; 703 if ((int)X->getType()->getX() != expectedXDim) { 704 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GER"); 705 } 706 int expectedYDim = 1 + (N - 1) * incY; 707 if ((int)Y->getType()->getX() != expectedYDim) { 708 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GER"); 709 } 710 711 712} 713static int validateSYR(RS* mRS, sp<const Element> e, RsBlasUplo Uplo, 714 sp<Allocation> X, int incX, sp<Allocation> A) { 715 if (!A->getType()->getElement()->isCompatible(e) || 716 !X->getType()->getElement()->isCompatible(e)) { 717 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); 718 } 719 720 int N = A->getType()->getX(); 721 722 if (X->getType()->getY() > 1) { 723 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1"); 724 } 725 if (N != (int)A->getType()->getY()) { 726 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "A must be a symmetric matrix"); 727 } 728 if (incX <= 0) { 729 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0"); 730 } 731 int expectedXDim = 1 + (N - 1) * incX; 732 if ((int)X->getType()->getX() != expectedXDim) { 733 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SYR"); 734 } 735 return N; 736} 737static int validateSPR(RS* mRS, sp<const Element> e, RsBlasUplo Uplo, 738 sp<Allocation> X, int incX, sp<Allocation> Ap) { 739 if (!Ap->getType()->getElement()->isCompatible(e) || 740 !X->getType()->getElement()->isCompatible(e)) { 741 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); 742 } 743 if (X->getType()->getY() > 1) { 744 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1"); 745 } 746 747 if (Ap->getType()->getY() > 1) { 748 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Ap must have a Y dimension of 0 or 1"); 749 } 750 751 int N = sqrt((double)Ap->getType()->getX() * 2); 752 if ((int)Ap->getType()->getX() != ((N * (N+1)) / 2)) { 753 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid dimension for Ap"); 754 } 755 if (incX <= 0) { 756 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0"); 757 } 758 int expectedXDim = 1 + (N - 1) * incX; 759 if ((int)X->getType()->getX() != expectedXDim) { 760 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SPR"); 761 } 762 763 return N; 764} 765 766static int validateSYR2(RS* mRS, sp<const Element> e, RsBlasUplo Uplo, sp<Allocation> X, 767 int incX, sp<Allocation> Y, int incY, sp<Allocation> A) { 768 if (!A->getType()->getElement()->isCompatible(e) || 769 !X->getType()->getElement()->isCompatible(e) || 770 !Y->getType()->getElement()->isCompatible(e)) { 771 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); 772 } 773 774 if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) { 775 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1"); 776 } 777 778 int N = A->getType()->getX(); 779 780 if (N != (int)A->getType()->getY()) { 781 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "A must be a symmetric matrix"); 782 } 783 if (incX <= 0 || incY <= 0) { 784 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0"); 785 } 786 int expectedXDim = 1 + (N - 1) * incX; 787 int expectedYDim = 1 + (N - 1) * incY; 788 if ((int)X->getType()->getX() != expectedXDim || (int)Y->getType()->getX() != expectedYDim) { 789 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SYR"); 790 } 791 return N; 792 793} 794static int validateSPR2(RS* mRS, sp<const Element> e, RsBlasUplo Uplo, sp<Allocation> X, 795 int incX, sp<Allocation> Y, int incY, sp<Allocation> Ap) { 796 if (!Ap->getType()->getElement()->isCompatible(e) || 797 !X->getType()->getElement()->isCompatible(e) || 798 !Y->getType()->getElement()->isCompatible(e)) { 799 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); 800 } 801 if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) { 802 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1"); 803 } 804 805 if (Ap->getType()->getY() > 1) { 806 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Ap must have a Y dimension of 0 or 1"); 807 } 808 809 int N = sqrt((double)Ap->getType()->getX() * 2); 810 if ((int)Ap->getType()->getX() != ((N * (N+1)) / 2)) { 811 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid dimension for Ap"); 812 } 813 if (incX <= 0 || incY <= 0) { 814 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0"); 815 } 816 int expectedXDim = 1 + (N - 1) * incX; 817 int expectedYDim = 1 + (N - 1) * incY; 818 if ((int)X->getType()->getX() != expectedXDim || (int)Y->getType()->getX() != expectedYDim) { 819 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SPR2"); 820 } 821 822 return N; 823} 824 825void ScriptIntrinsicBLAS::SSYMV(RsBlasUplo Uplo, float alpha, sp<Allocation> A, sp<Allocation> X, 826 int incX, float beta, sp<Allocation> Y, int incY) { 827 int N = validateSYMV(mRS, Element::F32(mRS), Uplo, A, X, Y, incX, incY); 828 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssymv, 829 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 830 A->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0); 831} 832 833void ScriptIntrinsicBLAS::SSBMV(RsBlasUplo Uplo, int K, float alpha, sp<Allocation> A, sp<Allocation> X, 834 int incX, float beta, sp<Allocation> Y, int incY) { 835 // SBMV is the same as SYMV + K >= 0 836 if (K < 0) { 837 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0"); 838 } 839 int N = validateSYMV(mRS, Element::F32(mRS), Uplo, A, X, Y, incX, incY); 840 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssbmv, 841 0, 0, 0, Uplo, 0, 0, N, K, alpha, 842 A->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0); 843} 844 845void ScriptIntrinsicBLAS::SSPMV(RsBlasUplo Uplo, float alpha, sp<Allocation> Ap, sp<Allocation> X, 846 int incX, float beta, sp<Allocation> Y, int incY) { 847 int N = validateSPMV(mRS, Element::F32(mRS), Uplo, Ap, X, incX, Y, incY); 848 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sspmv, 849 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 850 Ap->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0); 851} 852 853void ScriptIntrinsicBLAS::SGER(float alpha, sp<Allocation> X, int incX, 854 sp<Allocation> Y, int incY, sp<Allocation> A) { 855 int M = A->getType()->getY(); 856 int N = A->getType()->getX(); 857 validateGER(mRS, Element::F32(mRS), X, incX, Y, incY, A); 858 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sger, 859 0, 0, 0, 0, 0, M, N, 0, alpha, 860 X->getID(), Y->getID(), 0.f, A->getID(), incX, incY, 0, 0); 861} 862 863void ScriptIntrinsicBLAS::SSYR(RsBlasUplo Uplo, float alpha, sp<Allocation> X, 864 int incX, sp<Allocation> A) { 865 int N = validateSYR(mRS, Element::F32(mRS), Uplo, X, incX, A); 866 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssyr, 867 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 868 X->getID(), A->getID(), 0.f, 0, incX, 0, 0, 0); 869} 870 871void ScriptIntrinsicBLAS::SSPR(RsBlasUplo Uplo, float alpha, sp<Allocation> X, 872 int incX, sp<Allocation> Ap) { 873 int N = validateSPR(mRS, Element::F32(mRS), Uplo, X, incX, Ap); 874 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sspr, 875 0, 0, 0, Uplo, 0, 0, N, 0, 876 alpha, X->getID(), Ap->getID(), 0.f, 0, incX, 0, 0, 0); 877} 878 879void ScriptIntrinsicBLAS::SSYR2(RsBlasUplo Uplo, float alpha, sp<Allocation> X, int incX, 880 sp<Allocation> Y, int incY, sp<Allocation> A) { 881 int N = validateSYR2(mRS, Element::F32(mRS), Uplo, X, incX, Y, incY, A); 882 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssyr2, 883 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 884 X->getID(), Y->getID(), 0, A->getID(), incX, incY, 0, 0); 885} 886 887void ScriptIntrinsicBLAS::SSPR2(RsBlasUplo Uplo, float alpha, sp<Allocation> X, int incX, 888 sp<Allocation> Y, int incY, sp<Allocation> Ap) { 889 int N = validateSPR2(mRS, Element::F32(mRS), Uplo, X, incX, Y, incY, Ap); 890 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sspr2, 891 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 892 X->getID(), Y->getID(), 0, Ap->getID(), incX, incY, 0, 0); 893} 894 895void ScriptIntrinsicBLAS::DSYMV(RsBlasUplo Uplo, double alpha, sp<Allocation> A, sp<Allocation> X, 896 int incX, double beta, sp<Allocation> Y, int incY) { 897 int N = validateSYMV(mRS, Element::F64(mRS), Uplo, A, X, Y, incX, incY); 898 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsymv, 899 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 900 A->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0); 901} 902 903void ScriptIntrinsicBLAS::DSBMV(RsBlasUplo Uplo, int K, double alpha, sp<Allocation> A, sp<Allocation> X, 904 int incX, double beta, sp<Allocation> Y, int incY) { 905 // SBMV is the same as SYMV + K >= 0 906 if (K < 0) { 907 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0"); 908 } 909 int N = validateSYMV(mRS, Element::F64(mRS), Uplo, A, X, Y, incX, incY); 910 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsbmv, 911 0, 0, 0, Uplo, 0, 0, N, K, alpha, 912 A->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0); 913} 914 915void ScriptIntrinsicBLAS::DSPMV(RsBlasUplo Uplo, double alpha, sp<Allocation> Ap, sp<Allocation> X, 916 int incX, double beta, sp<Allocation> Y, int incY) { 917 int N = validateSPMV(mRS, Element::F64(mRS), Uplo, Ap, X, incX, Y, incY); 918 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dspmv, 919 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 920 Ap->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0); 921} 922 923void ScriptIntrinsicBLAS::DGER(double alpha, sp<Allocation> X, int incX, sp<Allocation> Y, 924 int incY, sp<Allocation> A) { 925 int M = A->getType()->getY(); 926 int N = A->getType()->getX(); 927 validateGER(mRS, Element::F64(mRS), X, incX, Y, incY, A); 928 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dger, 929 0, 0, 0, 0, 0, M, N, 0, alpha, 930 X->getID(), Y->getID(), 0.f, A->getID(), incX, incY, 0, 0); 931} 932 933void ScriptIntrinsicBLAS::DSYR(RsBlasUplo Uplo, double alpha, sp<Allocation> X, 934 int incX, sp<Allocation> A) { 935 int N = validateSYR(mRS, Element::F64(mRS), Uplo, X, incX, A); 936 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsyr, 937 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 938 X->getID(), A->getID(), 0.f, 0, incX, 0, 0, 0); 939} 940 941void ScriptIntrinsicBLAS::DSPR(RsBlasUplo Uplo, double alpha, sp<Allocation> X, 942 int incX, sp<Allocation> Ap) { 943 int N = validateSPR(mRS, Element::F64(mRS), Uplo, X, incX, Ap); 944 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dspr, 945 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 946 X->getID(), Ap->getID(), 0.f, 0, incX, 0, 0, 0); 947} 948 949void ScriptIntrinsicBLAS::DSYR2(RsBlasUplo Uplo, double alpha, sp<Allocation> X, int incX, 950 sp<Allocation> Y, int incY, sp<Allocation> A) { 951 int N = validateSYR2(mRS, Element::F64(mRS), Uplo, X, incX, Y, incY, A); 952 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsyr2, 953 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 954 X->getID(), Y->getID(), 0, A->getID(), incX, incY, 0, 0); 955} 956 957void ScriptIntrinsicBLAS::DSPR2(RsBlasUplo Uplo, double alpha, sp<Allocation> X, int incX, 958 sp<Allocation> Y, int incY, sp<Allocation> Ap) { 959 int N = validateSPR2(mRS, Element::F64(mRS), Uplo, X, incX, Y, incY, Ap); 960 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dspr2, 961 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 962 X->getID(), Y->getID(), 0, Ap->getID(), incX, incY, 0, 0); 963} 964 965 966/** 967 * Level 2, C and Z only 968 */ 969 970static void validateGERU(RS* mRS, sp<const Element> e, sp<Allocation> X, int incX, 971 sp<Allocation> Y, int incY, sp<Allocation> A) { 972 if (!A->getType()->getElement()->isCompatible(e) || 973 !X->getType()->getElement()->isCompatible(e) || 974 !Y->getType()->getElement()->isCompatible(e)) { 975 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); 976 } 977 if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) { 978 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1"); 979 } 980 981 int M = A->getType()->getY(); 982 int N = A->getType()->getX(); 983 if (incX <= 0 || incY <= 0) { 984 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0"); 985 } 986 int expectedXDim = 1 + (M - 1) * incX; 987 if ((int)X->getType()->getX() != expectedXDim) { 988 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GERU"); 989 } 990 int expectedYDim = 1 + (N - 1) * incY; 991 if ((int)Y->getType()->getX() != expectedYDim) { 992 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GERU"); 993 } 994 995} 996 997void ScriptIntrinsicBLAS::CHEMV(RsBlasUplo Uplo, Float2 alpha, sp<Allocation> A, 998 sp<Allocation> X, int incX, Float2 beta, sp<Allocation> Y, int incY) { 999 // HEMV is the same as SYR2 validation-wise 1000 int N = validateSYR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, A); 1001 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chemv, 1002 0, 0, 0, Uplo, 0, 0, N, 0, 1003 alpha.x, alpha.y, A->getID(), X->getID(), 1004 beta.x, beta.y, Y->getID(), incX, incY, 0, 0); 1005} 1006 1007void ScriptIntrinsicBLAS::CHBMV(RsBlasUplo Uplo, int K, Float2 alpha, sp<Allocation> A, 1008 sp<Allocation> X, int incX, Float2 beta, sp<Allocation> Y, int incY) { 1009 // HBMV is the same as SYR2 validation-wise 1010 int N = validateSYR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, A); 1011 if (K < 0) { 1012 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be 0 or greater for HBMV"); 1013 } 1014 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chbmv, 1015 0, 0, 0, Uplo, 0, 0, N, K, 1016 alpha.x, alpha.y, A->getID(), X->getID(), 1017 beta.x, beta.y, Y->getID(), incX, incY, 0, 0); 1018} 1019 1020void ScriptIntrinsicBLAS::CHPMV(RsBlasUplo Uplo, Float2 alpha, sp<Allocation> Ap, 1021 sp<Allocation> X, int incX, Float2 beta, sp<Allocation> Y, int incY) { 1022 // HPMV is the same as SPR2 1023 int N = validateSPR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, Ap); 1024 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chpmv, 1025 0, 0, 0, Uplo, 0, 0, N, 0, 1026 alpha.x, alpha.y, Ap->getID(), X->getID(), 1027 beta.x, beta.y, Y->getID(), incX, incY, 0, 0); 1028} 1029 1030void ScriptIntrinsicBLAS::CGERU(Float2 alpha, sp<Allocation> X, int incX, 1031 sp<Allocation> Y, int incY, sp<Allocation> A) { 1032 validateGERU(mRS, Element::F32_2(mRS), X, incX, Y, incY, A); 1033 int M = A->getType()->getY(); 1034 int N = A->getType()->getX(); 1035 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgeru, 1036 0, 0, 0, 0, 0, M, N, 0, 1037 alpha.x, alpha.y, X->getID(), Y->getID(), 1038 0, 0, A->getID(), incX, incY, 0, 0); 1039} 1040 1041void ScriptIntrinsicBLAS::CGERC(Float2 alpha, sp<Allocation> X, int incX, 1042 sp<Allocation> Y, int incY, sp<Allocation> A) { 1043 // Same as GERU 1044 validateGERU(mRS, Element::F32_2(mRS), X, incX, Y, incY, A); 1045 int M = A->getType()->getY(); 1046 int N = A->getType()->getX(); 1047 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgerc, 1048 0, 0, 0, 0, 0, M, N, 0, 1049 alpha.x, alpha.y, X->getID(), Y->getID(), 1050 0, 0, A->getID(), incX, incY, 0, 0); 1051} 1052 1053void ScriptIntrinsicBLAS::CHER(RsBlasUplo Uplo, float alpha, sp<Allocation> X, 1054 int incX, sp<Allocation> A) { 1055 // Same as SYR 1056 int N = validateSYR(mRS, Element::F32_2(mRS), Uplo, X, incX, A); 1057 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cher, 1058 0, 0, 0, Uplo, 0, 0, N, 0, 1059 alpha, 0, X->getID(), 0, 1060 0, 0, A->getID(), incX, 0, 0, 0); 1061} 1062 1063void ScriptIntrinsicBLAS::CHPR(RsBlasUplo Uplo, float alpha, sp<Allocation> X, 1064 int incX, sp<Allocation> Ap) { 1065 // Equivalent to SPR for validation 1066 int N = validateSPR(mRS, Element::F32_2(mRS), Uplo, X, incX, Ap); 1067 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chpr, 1068 0, 0, 0, Uplo, 0, 0, N, 0, 1069 alpha, 0, X->getID(), 0, 1070 0, 0, Ap->getID(), incX, 0, 0, 0); 1071} 1072 1073void ScriptIntrinsicBLAS::CHER2(RsBlasUplo Uplo, Float2 alpha, sp<Allocation> X, int incX, 1074 sp<Allocation> Y, int incY, sp<Allocation> A) { 1075 // Same as SYR2 1076 int N = validateSYR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, A); 1077 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cher2, 1078 0, 0, 0, Uplo, 0, 0, N, 0, 1079 alpha.x, alpha.y, X->getID(), Y->getID(), 1080 0, 0, A->getID(), incX, incY, 0, 0); 1081} 1082 1083void ScriptIntrinsicBLAS::CHPR2(RsBlasUplo Uplo, Float2 alpha, sp<Allocation> X, int incX, 1084 sp<Allocation> Y, int incY, sp<Allocation> Ap) { 1085 // Same as SPR2 1086 int N = validateSPR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, Ap); 1087 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chpr2, 1088 0, 0, 0, Uplo, 0, 0, N, 0, 1089 alpha.x, alpha.y, X->getID(), Y->getID(), 1090 0, 0, Ap->getID(), incX, incY, 0, 0); 1091} 1092 1093void ScriptIntrinsicBLAS::ZHEMV(RsBlasUplo Uplo, Double2 alpha, sp<Allocation> A, 1094 sp<Allocation> X, int incX, Double2 beta, sp<Allocation> Y, int incY) { 1095 // HEMV is the same as SYR2 validation-wise 1096 int N = validateSYR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, A); 1097 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhemv, 1098 0, 0, 0, Uplo, 0, 0, N, 0, 1099 alpha.x, alpha.y, A->getID(), X->getID(), 1100 beta.x, beta.y, Y->getID(), incX, incY, 0, 0); 1101} 1102 1103void ScriptIntrinsicBLAS::ZHBMV(RsBlasUplo Uplo, int K, Double2 alpha, sp<Allocation> A, sp<Allocation> X, 1104 int incX, Double2 beta, sp<Allocation> Y, int incY) { 1105 // HBMV is the same as SYR2 validation-wise 1106 int N = validateSYR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, A); 1107 if (K < 0) { 1108 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be 0 or greater for HBMV"); 1109 } 1110 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhbmv, 1111 0, 0, 0, Uplo, 0, 0, N, K, 1112 alpha.x, alpha.y, A->getID(), X->getID(), 1113 beta.x, beta.y, Y->getID(), incX, incY, 0, 0); 1114} 1115 1116void ScriptIntrinsicBLAS::ZHPMV(RsBlasUplo Uplo, Double2 alpha, sp<Allocation> Ap, sp<Allocation> X, 1117 int incX, Double2 beta, sp<Allocation> Y, int incY) { 1118 // HPMV is the same as SPR2 1119 int N = validateSPR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, Ap); 1120 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhpmv, 1121 0, 0, 0, Uplo, 0, 0, N, 0, 1122 alpha.x, alpha.y, Ap->getID(), X->getID(), 1123 beta.x, beta.y, Y->getID(), incX, incY, 0, 0); 1124} 1125 1126void ScriptIntrinsicBLAS::ZGERU(Double2 alpha, sp<Allocation> X, int incX, 1127 sp<Allocation> Y, int incY, sp<Allocation> A) { 1128 validateGERU(mRS, Element::F64_2(mRS), X, incX, Y, incY, A); 1129 int M = A->getType()->getY(); 1130 int N = A->getType()->getX(); 1131 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgeru, 1132 0, 0, 0, 0, 0, M, N, 0, 1133 alpha.x, alpha.y, X->getID(), Y->getID(), 1134 0, 0, A->getID(), incX, incY, 0, 0); 1135} 1136 1137void ScriptIntrinsicBLAS::ZGERC(Double2 alpha, sp<Allocation> X, int incX, 1138 sp<Allocation> Y, int incY, sp<Allocation> A) { 1139 // Same as GERU 1140 validateGERU(mRS, Element::F64_2(mRS), X, incX, Y, incY, A); 1141 int M = A->getType()->getY(); 1142 int N = A->getType()->getX(); 1143 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgerc, 1144 0, 0, 0, 0, 0, M, N, 0, 1145 alpha.x, alpha.y, X->getID(), Y->getID(), 1146 0, 0, A->getID(), incX, incY, 0, 0); 1147} 1148 1149void ScriptIntrinsicBLAS::ZHER(RsBlasUplo Uplo, double alpha, sp<Allocation> X, 1150 int incX, sp<Allocation> A) { 1151 // Same as SYR 1152 int N = validateSYR(mRS, Element::F64_2(mRS), Uplo, X, incX, A); 1153 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zher, 1154 0, 0, 0, Uplo, 0, 0, N, 0, 1155 alpha, 0, X->getID(), 0, 1156 0, 0, A->getID(), incX, 0, 0, 0); 1157} 1158 1159void ScriptIntrinsicBLAS::ZHPR(RsBlasUplo Uplo, double alpha, sp<Allocation> X, 1160 int incX, sp<Allocation> Ap) { 1161 // Equivalent to SPR for validation 1162 int N = validateSPR(mRS, Element::F64_2(mRS), Uplo, X, incX, Ap); 1163 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhpr, 1164 0, 0, 0, Uplo, 0, 0, N, 0, 1165 alpha, 0, X->getID(), 0, 1166 0, 0, Ap->getID(), incX, 0, 0, 0); 1167} 1168 1169void ScriptIntrinsicBLAS::ZHER2(RsBlasUplo Uplo, Double2 alpha, sp<Allocation> X, int incX, 1170 sp<Allocation> Y, int incY, sp<Allocation> A) { 1171 // Same as SYR2 1172 int N = validateSYR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, A); 1173 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zher2, 1174 0, 0, 0, Uplo, 0, 0, N, 0, 1175 alpha.x, alpha.y, X->getID(), Y->getID(), 1176 0, 0, A->getID(), incX, incY, 0, 0); 1177} 1178 1179void ScriptIntrinsicBLAS::ZHPR2(RsBlasUplo Uplo, Double2 alpha, sp<Allocation> X, int incX, 1180 sp<Allocation> Y, int incY, sp<Allocation> Ap) { 1181 // Same as SPR2 1182 int N = validateSPR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, Ap); 1183 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhpr2, 1184 0, 0, 0, Uplo, 0, 0, N, 0, 1185 alpha.x, alpha.y, X->getID(), Y->getID(), 1186 0, 0, Ap->getID(), incX, incY, 0, 0); 1187} 1188 1189 1190/** 1191 * Level 3 BLAS 1192 */ 1193 1194static void validateL3(RS* mRS, sp<const Element> e, int TransA, int TransB, int Side, 1195 sp<Allocation> A, sp<Allocation> B, sp<Allocation> C) { 1196 int aM = -1, aN = -1, bM = -1, bN = -1, cM = -1, cN = -1; 1197 if ((A != nullptr && !A->getType()->getElement()->isCompatible(e)) || 1198 (B != nullptr && !B->getType()->getElement()->isCompatible(e)) || 1199 (C != nullptr && !C->getType()->getElement()->isCompatible(e))) { 1200 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); 1201 } 1202 if (C == nullptr) { 1203 // Since matrix C is used to store the result, it cannot be null. 1204 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Allocation C cannot be null"); 1205 } 1206 cM = C->getType()->getY(); 1207 cN = C->getType()->getX(); 1208 1209 if (Side == RsBlasRight) { 1210 if ((A == nullptr && B != nullptr) || (A != nullptr && B == nullptr)) { 1211 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Provided Matrix A without Matrix B, or vice versa"); 1212 } 1213 if (B != nullptr) { 1214 bM = A->getType()->getY(); 1215 bN = A->getType()->getX(); 1216 } 1217 if (A != nullptr) { 1218 aM = B->getType()->getY(); 1219 aN = B->getType()->getX(); 1220 } 1221 } else { 1222 if (A != nullptr) { 1223 if (TransA == RsBlasTrans || TransA == RsBlasConjTrans) { 1224 aN = A->getType()->getY(); 1225 aM = A->getType()->getX(); 1226 } else { 1227 aM = A->getType()->getY(); 1228 aN = A->getType()->getX(); 1229 } 1230 } 1231 if (B != nullptr) { 1232 if (TransB == RsBlasTrans || TransB == RsBlasConjTrans) { 1233 bN = B->getType()->getY(); 1234 bM = B->getType()->getX(); 1235 } else { 1236 bM = B->getType()->getY(); 1237 bN = B->getType()->getX(); 1238 } 1239 } 1240 } 1241 if (A != nullptr && B != nullptr && C != nullptr) { 1242 if (aN != bM || aM != cM || bN != cN) { 1243 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called BLAS with invalid dimensions"); 1244 } 1245 } else if (A != nullptr && C != nullptr) { 1246 // A and C only, for SYRK 1247 if (cM != cN) { 1248 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix C is not symmetric"); 1249 } 1250 if (aM != cM) { 1251 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called BLAS with invalid dimensions"); 1252 } 1253 } else if (A != nullptr && B != nullptr) { 1254 // A and B only 1255 if (aN != bM) { 1256 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called BLAS with invalid dimensions"); 1257 } 1258 } 1259 1260} 1261 1262void ScriptIntrinsicBLAS::SGEMM(RsBlasTranspose TransA, RsBlasTranspose TransB, float alpha, 1263 sp<Allocation> A, sp<Allocation> B, float beta, sp<Allocation> C) { 1264 validateL3(mRS, Element::F32(mRS), TransA, TransB, 0, A, B, C); 1265 1266 int M = -1, N = -1, K = -1; 1267 if (TransA != RsBlasNoTrans) { 1268 M = A->getType()->getX(); 1269 K = A->getType()->getY(); 1270 } else { 1271 M = A->getType()->getY(); 1272 K = A->getType()->getX(); 1273 } 1274 if (TransB != RsBlasNoTrans) { 1275 N = B->getType()->getY(); 1276 } else { 1277 N = B->getType()->getX(); 1278 } 1279 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sgemm, 1280 TransA, TransB, 0, 0, 0, M, N, K, 1281 alpha, A->getID(), B->getID(), 1282 beta, C->getID(), 0, 0, 0, 0); 1283} 1284 1285void ScriptIntrinsicBLAS::DGEMM(RsBlasTranspose TransA, RsBlasTranspose TransB, double alpha, 1286 sp<Allocation> A, sp<Allocation> B, double beta, sp<Allocation> C) { 1287 validateL3(mRS, Element::F64(mRS), TransA, TransB, 0, A, B, C); 1288 int M = -1, N = -1, K = -1; 1289 if (TransA != RsBlasNoTrans) { 1290 M = A->getType()->getX(); 1291 K = A->getType()->getY(); 1292 } else { 1293 M = A->getType()->getY(); 1294 K = A->getType()->getX(); 1295 } 1296 if (TransB != RsBlasNoTrans) { 1297 N = B->getType()->getY(); 1298 } else { 1299 N = B->getType()->getX(); 1300 } 1301 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dgemm, 1302 TransA, TransB, 0, 0, 0, M, N, K, 1303 alpha, A->getID(), B->getID(), 1304 beta, C->getID(), 0, 0, 0, 0); 1305} 1306 1307void ScriptIntrinsicBLAS::CGEMM(RsBlasTranspose TransA, RsBlasTranspose TransB, Float2 alpha, 1308 sp<Allocation> A, sp<Allocation> B, Float2 beta, sp<Allocation> C) { 1309 validateL3(mRS, Element::F32_2(mRS), TransA, TransB, 0, A, B, C); 1310 int M = -1, N = -1, K = -1; 1311 if (TransA != RsBlasNoTrans) { 1312 M = A->getType()->getX(); 1313 K = A->getType()->getY(); 1314 } else { 1315 M = A->getType()->getY(); 1316 K = A->getType()->getX(); 1317 } 1318 if (TransB != RsBlasNoTrans) { 1319 N = B->getType()->getY(); 1320 } else { 1321 N = B->getType()->getX(); 1322 } 1323 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgemm, 1324 TransA, TransB, 0, 0, 0, M, N, K, 1325 alpha.x, alpha.y, A->getID(), B->getID(), 1326 beta.x, beta.y, C->getID(), 0, 0, 0, 0); 1327} 1328 1329void ScriptIntrinsicBLAS::ZGEMM(RsBlasTranspose TransA, RsBlasTranspose TransB, Double2 alpha, 1330 sp<Allocation> A, sp<Allocation> B, Double2 beta, sp<Allocation> C) { 1331 validateL3(mRS, Element::F64_2(mRS), TransA, TransB, 0, A, B, C); 1332 int M = -1, N = -1, K = -1; 1333 if (TransA != RsBlasNoTrans) { 1334 M = A->getType()->getX(); 1335 K = A->getType()->getY(); 1336 } else { 1337 M = A->getType()->getY(); 1338 K = A->getType()->getX(); 1339 } 1340 if (TransB != RsBlasNoTrans) { 1341 N = B->getType()->getY(); 1342 } else { 1343 N = B->getType()->getX(); 1344 } 1345 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgemm, 1346 TransA, TransB, 0, 0, 0, M, N, K, 1347 alpha.x, alpha.y, A->getID(), B->getID(), 1348 beta.x, beta.y, C->getID(), 0, 0, 0, 0); 1349} 1350 1351void ScriptIntrinsicBLAS::SSYMM(RsBlasSide Side, RsBlasUplo Uplo, float alpha, 1352 sp<Allocation> A, sp<Allocation> B, float beta, sp<Allocation> C) { 1353 //For SYMM, Matrix A should be symmetric 1354 if (A->getType()->getX() != A->getType()->getY()) { 1355 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix A is not symmetric"); 1356 } 1357 validateL3(mRS, Element::F32(mRS), 0, 0, Side, A, B, C); 1358 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssymm, 1359 0, 0, Side, Uplo, 0, C->getType()->getY(), C->getType()->getX(), 0, 1360 alpha, A->getID(), B->getID(), 1361 beta, C->getID(), 0, 0, 0, 0); 1362} 1363 1364void ScriptIntrinsicBLAS::DSYMM(RsBlasSide Side, RsBlasUplo Uplo, double alpha, 1365 sp<Allocation> A, sp<Allocation> B, double beta, sp<Allocation> C) { 1366 if (A->getType()->getX() != A->getType()->getY()) { 1367 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix A is not symmetric"); 1368 } 1369 validateL3(mRS, Element::F64(mRS), 0, 0, Side, A, B, C); 1370 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsymm, 1371 0, 0, Side, Uplo, 0, C->getType()->getY(), C->getType()->getX(), 0, 1372 alpha, A->getID(), B->getID(), 1373 beta, C->getID(), 0, 0, 0, 0); 1374} 1375 1376void ScriptIntrinsicBLAS::CSYMM(RsBlasSide Side, RsBlasUplo Uplo, Float2 alpha, 1377 sp<Allocation> A, sp<Allocation> B, Float2 beta, sp<Allocation> C) { 1378 if (A->getType()->getX() != A->getType()->getY()) { 1379 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix A is not symmetric"); 1380 } 1381 validateL3(mRS, Element::F32_2(mRS), 0, 0, Side, A, B, C); 1382 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_csymm, 1383 0, 0, Side, Uplo, 0, C->getType()->getY(), C->getType()->getX(), 0, 1384 alpha.x, alpha.y, A->getID(), B->getID(), 1385 beta.x, beta.y, C->getID(), 0, 0, 0, 0); 1386} 1387 1388void ScriptIntrinsicBLAS::ZSYMM(RsBlasSide Side, RsBlasUplo Uplo, Double2 alpha, 1389 sp<Allocation> A, sp<Allocation> B, Double2 beta, sp<Allocation> C) { 1390 if (A->getType()->getX() != A->getType()->getY()) { 1391 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix A is not symmetric"); 1392 } 1393 validateL3(mRS, Element::F64_2(mRS), 0, 0, Side, A, B, C); 1394 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zsymm, 1395 0, 0, Side, Uplo, 0, C->getType()->getY(), C->getType()->getX(), 0, 1396 alpha.x, alpha.y, A->getID(), B->getID(), 1397 beta.x, beta.y, C->getID(), 0, 0, 0, 0); 1398} 1399 1400void ScriptIntrinsicBLAS::SSYRK(RsBlasUplo Uplo, RsBlasTranspose Trans, float alpha, 1401 sp<Allocation> A, float beta, sp<Allocation> C) { 1402 validateL3(mRS, Element::F32(mRS), Trans, 0, 0, A, nullptr, C); 1403 int K = -1; 1404 if (Trans != RsBlasNoTrans) { 1405 K = A->getType()->getY(); 1406 } else { 1407 K = A->getType()->getX(); 1408 } 1409 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssyrk, 1410 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K, 1411 alpha, A->getID(), 0, 1412 beta, C->getID(), 0, 0, 0, 0); 1413} 1414 1415void ScriptIntrinsicBLAS::DSYRK(RsBlasUplo Uplo, RsBlasTranspose Trans, double alpha, 1416 sp<Allocation> A, double beta, sp<Allocation> C) { 1417 validateL3(mRS, Element::F64(mRS), Trans, 0, 0, A, nullptr, C); 1418 int K = -1; 1419 if (Trans != RsBlasNoTrans) { 1420 K = A->getType()->getY(); 1421 } else { 1422 K = A->getType()->getX(); 1423 } 1424 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsyrk, 1425 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K, 1426 alpha, A->getID(), 0, 1427 beta, C->getID(), 0, 0, 0, 0); 1428} 1429 1430void ScriptIntrinsicBLAS::CSYRK(RsBlasUplo Uplo, RsBlasTranspose Trans, Float2 alpha, 1431 sp<Allocation> A, Float2 beta, sp<Allocation> C) { 1432 validateL3(mRS, Element::F32_2(mRS), Trans, 0, 0, A, nullptr, C); 1433 int K = -1; 1434 if (Trans != RsBlasNoTrans) { 1435 K = A->getType()->getY(); 1436 } else { 1437 K = A->getType()->getX(); 1438 } 1439 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_csyrk, 1440 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K, 1441 alpha.x, alpha.y, A->getID(), 0, 1442 beta.x, beta.y, C->getID(), 0, 0, 0, 0); 1443} 1444 1445void ScriptIntrinsicBLAS::ZSYRK(RsBlasUplo Uplo, RsBlasTranspose Trans, Double2 alpha, 1446 sp<Allocation> A, Double2 beta, sp<Allocation> C) { 1447 validateL3(mRS, Element::F64_2(mRS), Trans, 0, 0, A, nullptr, C); 1448 int K = -1; 1449 if (Trans != RsBlasNoTrans) { 1450 K = A->getType()->getY(); 1451 } else { 1452 K = A->getType()->getX(); 1453 } 1454 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zsyrk, 1455 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K, 1456 alpha.x, alpha.y, A->getID(), 0, 1457 beta.x, beta.y, C->getID(), 0, 0, 0, 0); 1458} 1459 1460static void validateSYR2K(RS* mRS, sp<const Element> e, RsBlasTranspose Trans, 1461 sp<Allocation> A, sp<Allocation> B, sp<Allocation> C) { 1462 if (!A->getType()->getElement()->isCompatible(e) || 1463 !B->getType()->getElement()->isCompatible(e) || 1464 !C->getType()->getElement()->isCompatible(e)) { 1465 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); 1466 } 1467 int Cdim = -1; 1468 // A is n x k if no transpose, k x n if transpose 1469 // C is n x n 1470 if (Trans == RsBlasTrans) { 1471 // check columns versus C 1472 Cdim = A->getType()->getX(); 1473 } else { 1474 // check rows versus C 1475 Cdim = A->getType()->getY(); 1476 } 1477 if ((int)C->getType()->getX() != Cdim || (int)C->getType()->getY() != Cdim) { 1478 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid symmetric matrix in SYR2K"); 1479 } 1480 // A dims == B dims 1481 if (A->getType()->getX() != B->getType()->getX() || A->getType()->getY() != B->getType()->getY()) { 1482 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid A and B in SYR2K"); 1483 } 1484} 1485 1486void ScriptIntrinsicBLAS::SSYR2K(RsBlasUplo Uplo, RsBlasTranspose Trans, float alpha, 1487 sp<Allocation> A, sp<Allocation> B, float beta, sp<Allocation> C) { 1488 validateSYR2K(mRS, Element::F32(mRS), Trans, A, B, C); 1489 int K = -1; 1490 if (Trans != RsBlasNoTrans) { 1491 K = A->getType()->getY(); 1492 } else { 1493 K = A->getType()->getX(); 1494 } 1495 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssyr2k, 1496 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K, 1497 alpha, A->getID(), B->getID(), 1498 beta, C->getID(), 0, 0, 0, 0); 1499} 1500 1501void ScriptIntrinsicBLAS::DSYR2K(RsBlasUplo Uplo, RsBlasTranspose Trans, double alpha, 1502 sp<Allocation> A, sp<Allocation> B, double beta, sp<Allocation> C) { 1503 validateSYR2K(mRS, Element::F64(mRS), Trans, A, B, C); 1504 int K = -1; 1505 if (Trans != RsBlasNoTrans) { 1506 K = A->getType()->getY(); 1507 } else { 1508 K = A->getType()->getX(); 1509 } 1510 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsyr2k, 1511 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K, 1512 alpha, A->getID(), B->getID(), 1513 beta, C->getID(), 0, 0, 0, 0); 1514} 1515 1516void ScriptIntrinsicBLAS::CSYR2K(RsBlasUplo Uplo, RsBlasTranspose Trans, Float2 alpha, 1517 sp<Allocation> A, sp<Allocation> B, Float2 beta, sp<Allocation> C) { 1518 validateSYR2K(mRS, Element::F32_2(mRS), Trans, A, B, C); 1519 int K = -1; 1520 if (Trans != RsBlasNoTrans) { 1521 K = A->getType()->getY(); 1522 } else { 1523 K = A->getType()->getX(); 1524 } 1525 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_csyr2k, 1526 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K, 1527 alpha.x, alpha.y, A->getID(), B->getID(), 1528 beta.x, beta.y, C->getID(), 0, 0, 0, 0); 1529} 1530 1531void ScriptIntrinsicBLAS::ZSYR2K(RsBlasUplo Uplo, RsBlasTranspose Trans, Double2 alpha, 1532 sp<Allocation> A, sp<Allocation> B, Double2 beta, sp<Allocation> C) { 1533 validateSYR2K(mRS, Element::F64_2(mRS), Trans, A, B, C); 1534 int K = -1; 1535 if (Trans != RsBlasNoTrans) { 1536 K = A->getType()->getY(); 1537 } else { 1538 K = A->getType()->getX(); 1539 } 1540 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zsyr2k, 1541 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K, 1542 alpha.x, alpha.y, A->getID(), B->getID(), 1543 beta.x, beta.y, C->getID(), 0, 0, 0, 0); 1544} 1545 1546static void validateTRMM(RS* mRS, sp<const Element> e, RsBlasSide Side, RsBlasTranspose TransA, 1547 sp<Allocation> A, sp<Allocation> B) { 1548 int aM = -1, aN = -1, bM = -1, bN = -1; 1549 if (!A->getType()->getElement()->isCompatible(e) || 1550 !B->getType()->getElement()->isCompatible(e)) { 1551 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); 1552 } 1553 1554 aM = A->getType()->getY(); 1555 aN = A->getType()->getX(); 1556 if (aM != aN) { 1557 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRMM with a non-symmetric matrix A"); 1558 } 1559 1560 bM = B->getType()->getY(); 1561 bN = B->getType()->getX(); 1562 if (Side == RsBlasLeft) { 1563 if (aN != bM) { 1564 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRMM with invalid matrices"); 1565 } 1566 } else { 1567 if (bN != aM) { 1568 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRMM with invalid matrices"); 1569 } 1570 } 1571} 1572 1573void ScriptIntrinsicBLAS::STRMM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 1574 float alpha, sp<Allocation> A, sp<Allocation> B) { 1575 validateTRMM(mRS, Element::F32(mRS), Side, TransA, A, B); 1576 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_strmm, 1577 TransA, 0, Side, Uplo, Diag,\ 1578 B->getType()->getY(), B->getType()->getX(), 0, 1579 alpha, A->getID(), B->getID(), 0.f, 0, 0, 0, 0, 0); 1580} 1581 1582void ScriptIntrinsicBLAS::DTRMM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 1583 double alpha, sp<Allocation> A, sp<Allocation> B) { 1584 validateTRMM(mRS, Element::F64(mRS), Side, TransA, A, B); 1585 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtrmm, 1586 TransA, 0, Side, Uplo, Diag, 1587 B->getType()->getY(), B->getType()->getX(), 0, 1588 alpha, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0); 1589} 1590 1591void ScriptIntrinsicBLAS::CTRMM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 1592 Float2 alpha, sp<Allocation> A, sp<Allocation> B) { 1593 validateTRMM(mRS, Element::F32_2(mRS), Side, TransA, A, B); 1594 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctrmm, 1595 TransA, 0, Side, Uplo, Diag, 1596 B->getType()->getY(), B->getType()->getX(), 0, 1597 alpha.x, alpha.y, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0, 0); 1598} 1599 1600void ScriptIntrinsicBLAS::ZTRMM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 1601 Double2 alpha, sp<Allocation> A, sp<Allocation> B) { 1602 validateTRMM(mRS, Element::F64_2(mRS), Side, TransA, A, B); 1603 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztrmm, 1604 TransA, 0, Side, Uplo, Diag, 1605 B->getType()->getY(), B->getType()->getX(), 0, 1606 alpha.x, alpha.y, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0, 0); 1607} 1608 1609static void validateTRSM(RS* mRS, sp<const Element> e, RsBlasSide Side, RsBlasTranspose TransA, 1610 sp<Allocation> A, sp<Allocation> B) { 1611 int adim = -1, bM = -1, bN = -1; 1612 if (!A->getType()->getElement()->isCompatible(e) || 1613 !B->getType()->getElement()->isCompatible(e)) { 1614 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); 1615 } 1616 adim = A->getType()->getX(); 1617 if (adim != (int)A->getType()->getY()) { 1618 // This may be unnecessary, the restriction could potentially be relaxed. 1619 // Allocation A needs to contain at least that symmetric matrix but could theoretically 1620 // be larger for now we assume adapters are sufficient, will reevaluate in the future. 1621 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRSM with a non-symmetric matrix A"); 1622 } 1623 bM = B->getType()->getY(); 1624 bN = B->getType()->getX(); 1625 if (Side == RsBlasLeft) { 1626 // A is M*M 1627 if (adim != bM) { 1628 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRSM with invalid matrix dimensions"); 1629 } 1630 } else { 1631 // A is N*N 1632 if (adim != bN) { 1633 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRSM with invalid matrix dimensions"); 1634 } 1635 } 1636} 1637 1638void ScriptIntrinsicBLAS::STRSM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 1639 float alpha, sp<Allocation> A, sp<Allocation> B) { 1640 validateTRSM(mRS, Element::F32(mRS), Side, TransA, A, B); 1641 nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_strsm, 1642 TransA, 0, Side, Uplo, Diag, 1643 B->getType()->getY(), B->getType()->getX(), 0, 1644 alpha, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0); 1645} 1646 1647void ScriptIntrinsicBLAS::DTRSM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 1648 double alpha, sp<Allocation> A, sp<Allocation> B) { 1649 validateTRSM(mRS, Element::F64(mRS), Side, TransA, A, B); 1650 nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtrsm, 1651 TransA, 0, Side, Uplo, Diag, 1652 B->getType()->getY(), B->getType()->getX(), 0, 1653 alpha, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0); 1654} 1655 1656void ScriptIntrinsicBLAS::CTRSM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 1657 Float2 alpha, sp<Allocation> A, sp<Allocation> B) { 1658 validateTRSM(mRS, Element::F32_2(mRS), Side, TransA, A, B); 1659 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctrsm, 1660 TransA, 0, Side, Uplo, Diag, 1661 B->getType()->getY(), B->getType()->getX(), 0, 1662 alpha.x, alpha.y, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0, 0); 1663} 1664 1665void ScriptIntrinsicBLAS::ZTRSM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, 1666 Double2 alpha, sp<Allocation> A, sp<Allocation> B) { 1667 validateTRSM(mRS, Element::F64_2(mRS), Side, TransA, A, B); 1668 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztrsm, 1669 TransA, 0, Side, Uplo, Diag, 1670 B->getType()->getY(), B->getType()->getX(), 0, 1671 alpha.x, alpha.y, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0, 0); 1672} 1673 1674static void validateHEMM(RS* mRS, sp<const Element> e, RsBlasSide Side, 1675 sp<Allocation> A, sp<Allocation> B, sp<Allocation> C) { 1676 if (!A->getType()->getElement()->isCompatible(e) || 1677 !B->getType()->getElement()->isCompatible(e) || 1678 !C->getType()->getElement()->isCompatible(e)) { 1679 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); 1680 } 1681 1682 // A must be square; can potentially be relaxed similar to TRSM 1683 int adim = A->getType()->getX(); 1684 if (adim != (int)A->getType()->getY()) { 1685 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HEMM with non-square A"); 1686 } 1687 if ((Side == RsBlasLeft && adim != (int)B->getType()->getY()) || 1688 (Side == RsBlasRight && adim != (int)B->getType()->getX())) { 1689 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HEMM with invalid B"); 1690 } 1691 if (B->getType()->getX() != C->getType()->getX() || 1692 B->getType()->getY() != C->getType()->getY()) { 1693 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HEMM with mismatched B and C"); 1694 } 1695} 1696 1697void ScriptIntrinsicBLAS::CHEMM(RsBlasSide Side, RsBlasUplo Uplo, Float2 alpha, 1698 sp<Allocation> A, sp<Allocation> B, Float2 beta, sp<Allocation> C) { 1699 validateHEMM(mRS, Element::F32_2(mRS), Side, A, B, C); 1700 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chemm, 1701 0, 0, Side, Uplo, 0, 1702 C->getType()->getY(), C->getType()->getX(), 0, 1703 alpha.x, alpha.y, A->getID(), B->getID(), 1704 beta.x, beta.y, C->getID(), 0, 0, 0, 0); 1705} 1706 1707void ScriptIntrinsicBLAS::ZHEMM(RsBlasSide Side, RsBlasUplo Uplo, Double2 alpha, 1708 sp<Allocation> A, sp<Allocation> B, Double2 beta, sp<Allocation> C) { 1709 validateHEMM(mRS, Element::F64_2(mRS), Side, A, B, C); 1710 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhemm, 1711 0, 0, Side, Uplo, 0, 1712 C->getType()->getY(), C->getType()->getX(), 0, 1713 alpha.x, alpha.y, A->getID(), B->getID(), 1714 beta.x, beta.y, C->getID(), 0, 0, 0, 0); 1715} 1716 1717static void validateHERK(RS* mRS, sp<const Element> e, RsBlasTranspose Trans, 1718 sp<Allocation> A, sp<Allocation> C) { 1719 if (!A->getType()->getElement()->isCompatible(e) || 1720 !C->getType()->getElement()->isCompatible(e)) { 1721 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); 1722 } 1723 if (Trans != RsBlasNoTrans && Trans != RsBlasConjTrans) { 1724 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Call HERK with invalid Transpose"); 1725 } 1726 int cdim = C->getType()->getX(); 1727 if (cdim != (int)C->getType()->getY()) { 1728 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HERK with non-square C"); 1729 } 1730 if (Trans == RsBlasNoTrans) { 1731 if (cdim != (int)A->getType()->getY()) { 1732 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HERK with invalid A"); 1733 } 1734 } else { 1735 if (cdim != (int)A->getType()->getX()) { 1736 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HERK with invalid A"); 1737 } 1738 } 1739} 1740 1741void ScriptIntrinsicBLAS::CHERK(RsBlasUplo Uplo, RsBlasTranspose Trans, float alpha, 1742 sp<Allocation> A, float beta, sp<Allocation> C) { 1743 validateHERK(mRS, Element::F32_2(mRS), Trans, A, C); 1744 int k = 0; 1745 if (Trans == RsBlasConjTrans) { 1746 k = A->getType()->getY(); 1747 } else { 1748 k = A->getType()->getX(); 1749 } 1750 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cherk, 1751 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), k, 1752 alpha, 0, A->getID(), 0, 1753 beta, 0, C->getID(), 0, 0, 0, 0); 1754} 1755 1756void ScriptIntrinsicBLAS::ZHERK(RsBlasUplo Uplo, RsBlasTranspose Trans, double alpha, 1757 sp<Allocation> A, double beta, sp<Allocation> C) { 1758 validateHERK(mRS, Element::F64_2(mRS), Trans, A, C); 1759 int k = 0; 1760 if (Trans == RsBlasConjTrans) { 1761 k = A->getType()->getY(); 1762 } else { 1763 k = A->getType()->getX(); 1764 } 1765 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zherk, 1766 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), k, 1767 alpha, 0, A->getID(), 0, 1768 beta, 0, C->getID(), 0, 0, 0, 0); 1769} 1770 1771static void validateHER2K(RS* mRS, sp<const Element> e, RsBlasTranspose Trans, 1772 sp<Allocation> A, sp<Allocation> B, sp<Allocation> C) { 1773 if (!A->getType()->getElement()->isCompatible(e) || 1774 !B->getType()->getElement()->isCompatible(e) || 1775 !C->getType()->getElement()->isCompatible(e)) { 1776 mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); 1777 } 1778 if (Trans != RsBlasNoTrans && Trans != RsBlasConjTrans) { 1779 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Call HERK with invalid Transpose"); 1780 } 1781 int cdim = C->getType()->getX(); 1782 if (cdim != (int)C->getType()->getY()) { 1783 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HER2K with non-square C"); 1784 } 1785 if (Trans == RsBlasNoTrans) { 1786 if ((int)A->getType()->getY() != cdim) { 1787 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HER2K with invalid matrices"); 1788 } 1789 } else { 1790 if ((int)A->getType()->getX() != cdim) { 1791 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HER2K with invalid matrices"); 1792 } 1793 } 1794 if (A->getType()->getX() != B->getType()->getX() || A->getType()->getY() != B->getType()->getY()) { 1795 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HER2K with invalid A and B matrices"); 1796 } 1797} 1798 1799void ScriptIntrinsicBLAS::CHER2K(RsBlasUplo Uplo, RsBlasTranspose Trans, Float2 alpha, 1800 sp<Allocation> A, sp<Allocation> B, float beta, sp<Allocation> C) { 1801 validateHER2K(mRS, Element::F32_2(mRS), Trans, A, B, C); 1802 int k = 0; 1803 if (Trans == RsBlasNoTrans) { 1804 k = A->getType()->getX(); 1805 } else { 1806 k = A->getType()->getY(); 1807 } 1808 nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cher2k, 1809 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), k, 1810 alpha.x, alpha.y, A->getID(), B->getID(), 1811 beta, 0, C->getID(), 0, 0, 0, 0); 1812} 1813 1814void ScriptIntrinsicBLAS::ZHER2K(RsBlasUplo Uplo, RsBlasTranspose Trans, Double2 alpha, 1815 sp<Allocation> A, sp<Allocation> B, double beta, sp<Allocation> C) { 1816 validateHER2K(mRS, Element::F64_2(mRS), Trans, A, B, C); 1817 int k = 0; 1818 if (Trans == RsBlasNoTrans) { 1819 k = A->getType()->getX(); 1820 } else { 1821 k = A->getType()->getY(); 1822 } 1823 nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zher2k, 1824 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), k, 1825 alpha.x, alpha.y, A->getID(), B->getID(), 1826 beta, 0, C->getID(), 0, 0, 0, 0); 1827} 1828 1829 1830 1831void ScriptIntrinsicBLAS::BNNM(sp<Allocation> A, int a_offset, sp<Allocation> B, int b_offset, 1832 sp<Allocation> C, int c_offset, int c_mult) { 1833 validateL3(mRS, Element::U8(mRS), RsBlasNoTrans, RsBlasTrans, 0, A, B, C); 1834 1835 if (a_offset < 0 || a_offset > 255) { 1836 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid a_offset passed to BNNM"); 1837 } 1838 if (b_offset < 0 || b_offset > 255) { 1839 mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid b_offset passed to BNNM"); 1840 } 1841 int M = -1, N = -1, K = -1; 1842 M = A->getType()->getY(); 1843 N = B->getType()->getY(); 1844 K = A->getType()->getX(); 1845 1846 nScriptIntrinsicBLAS_BNNM(mRS, mRS->getContext(), getID(), M, N, K, A->getID(), a_offset, 1847 B->getID(), b_offset, C->getID(), c_offset, c_mult); 1848} 1849