A short while ago, I needed to do some matrix exponentiation in R (raising a matrix to a power). By default, the exponentiation operator ^ in R, when applied to a matrix, raises each element of the matrix to the power, rather than the matrix itself:
> A <- matrix(c(1:4), nrow=2, byrow=T)
> A
     [,1] [,2]
[1,]    1    2
[2,]    3    4
> A ^ 2
     [,1] [,2]
[1,]    1    4
[2,]    9   16
>
Whereas what we really want in this case is the equivalent of:
> A %*% A
     [,1] [,2]
[1,]    7   10
[2,]   15   22
>
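To make the difference concrete outside of R, here is a minimal C sketch of the two operations on the same 2x2 example. The function names elementwise_square and mat_mul are made up for this illustration and are not part of R or of the patch discussed below.

```c
#define N 2

/* Element-wise square: every entry is squared on its own.
   This mirrors what A ^ 2 does in R. */
static void elementwise_square(double a[N][N], double out[N][N])
{
    for (int i = 0; i < N; i++)
        for (int j = 0; j < N; j++)
            out[i][j] = a[i][j] * a[i][j];
}

/* Matrix product out = a * b (row-by-column sums).
   With b == a this mirrors A %*% A in R. out must not alias a or b. */
static void mat_mul(double a[N][N], double b[N][N], double out[N][N])
{
    for (int i = 0; i < N; i++)
        for (int j = 0; j < N; j++) {
            double sum = 0.0;
            for (int k = 0; k < N; k++)
                sum += a[i][k] * b[k][j];
            out[i][j] = sum;
        }
}
```

For a = {{1, 2}, {3, 4}}, elementwise_square fills out with {{1, 4}, {9, 16}}, while mat_mul(a, a, out) gives {{7, 10}, {15, 22}}, matching the two R transcripts above.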
There are a few different ways of creating a matrix exponentiation operator in R: we could write an R function and define an operator for it, similar to the %*% matrix multiplication operator that exists already, or we could write the function in C and link to it. It seems logical to create a %^% operator to match the current set of matrix operators. The choice of writing the exponentiation routine in R or C is actually less important than the exponentiation algorithm used. A fast exponentiation algorithm is the “square and multiply” method (also known as exponentiation by squaring), which needs on the order of log2(n) matrix multiplications to compute the nth power, rather than the n - 1 needed by repeated multiplication.
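As a rough illustration of square and multiply, here is a small C sketch on plain integers rather than matrices. The function name pow_square_multiply and the optional multiplication counter are inventions for this example only; the loop has the same shape as the one used for matrices, with scalar multiplication standing in for the matrix product.

```c
#include <stddef.h>
#include <stdint.h>

/* Compute base^exp by scanning the bits of exp from least significant
   to most significant. If mults is not NULL, it receives the number of
   multiplications performed. Overflow is not checked. */
uint64_t pow_square_multiply(uint64_t base, unsigned exp, unsigned *mults)
{
    uint64_t result = 1;   /* multiplicative identity */
    unsigned count = 0;

    while (exp > 0) {
        if (exp & 1u) {    /* this bit is set: fold the current power in */
            result *= base;
            count++;
        }
        exp >>= 1;
        if (exp > 0) {     /* more bits remain: square for the next bit */
            base *= base;
            count++;
        }
    }
    if (mults != NULL)
        *mults = count;
    return result;
}
```

For example, pow_square_multiply(3, 13, &m) returns 1594323 and sets m to 6, compared with the 12 multiplications that naive repeated multiplication would take.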
I chose to write a basic exponentiation routine in C, by creating a new operator definition, and following these basic rules:
- A %^% -1 returns the matrix inverse.
- A %^% 0 returns the identity matrix.
- A %^% 1 returns the original matrix.
- A %^% n returns A to the nth power.
- Complex as well as real-valued matrices should be supported.
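The rules above can be sketched in standalone C for the real-valued 2x2 case. This is only an illustration under some assumptions: the names mat2, mat2_mul and mat2_pow are hypothetical, and the inverse is computed with the closed-form 2x2 adjugate formula purely to keep the example self-contained, which is not how the routine in this post obtains it.

```c
#include <string.h>

typedef double mat2[2][2];

/* out = a * b; out must not alias a or b. */
static void mat2_mul(mat2 a, mat2 b, mat2 out)
{
    for (int i = 0; i < 2; i++)
        for (int j = 0; j < 2; j++)
            out[i][j] = a[i][0] * b[0][j] + a[i][1] * b[1][j];
}

/* Raise a to the power n following the rules listed above.
   Returns 0 on success, -1 if n < -1 or if n == -1 and a is singular.
   out must not alias a. */
int mat2_pow(mat2 a, int n, mat2 out)
{
    mat2 base, tmp;

    if (n < -1)
        return -1;
    if (n == -1) {                     /* inverse (2x2 adjugate formula) */
        double det = a[0][0] * a[1][1] - a[0][1] * a[1][0];
        if (det == 0.0)
            return -1;
        out[0][0] =  a[1][1] / det;
        out[0][1] = -a[0][1] / det;
        out[1][0] = -a[1][0] / det;
        out[1][1] =  a[0][0] / det;
        return 0;
    }

    memcpy(base, a, sizeof(mat2));
    out[0][0] = out[1][1] = 1.0;       /* start from the identity, */
    out[0][1] = out[1][0] = 0.0;       /* which is the answer for n == 0 */

    while (n > 0) {                    /* square and multiply */
        if (n & 1) {
            mat2_mul(out, base, tmp);
            memcpy(out, tmp, sizeof(mat2));
        }
        n >>= 1;
        if (n > 0) {
            mat2_mul(base, base, tmp);
            memcpy(base, tmp, sizeof(mat2));
        }
    }
    return 0;
}
```

With a = {{1, 2}, {3, 4}}, mat2_pow(a, 2, out) fills out with {{7, 10}, {15, 22}}, and mat2_pow(a, -1, out) gives {{-2, 1}, {1.5, -0.5}}. Starting from the identity means the n == 0 and n == 1 cases fall out of the same loop without special handling.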
The code for this is shown below. Note that to calculate the matrix inverse, I don’t call directly into the underlying BLAS/LAPACK libraries, but instead call back into R’s solve() function. The changes were made to array.c in the R source tree under /src/main.
SEXP do_matexp(SEXP call, SEXP op, SEXP args, SEXP rho)
{
    int nrows, ncols;
    SEXP matrix, tmp, dims, dims2;
    SEXP x, y, x_, x__;
    int i, e, mode;

    /* Coerce the base to double (or leave it complex) so that
       matprod()/cmatprod() can operate on it directly. */
    mode = isComplex(CAR(args)) ? CPLXSXP : REALSXP;
    SETCAR(args, coerceVector(CAR(args), mode));
    x = CAR(args);
    y = CADR(args);

    /* Without a dim attribute, INTEGER(dims) below would be invalid. */
    if (!isMatrix(x))
        error(_("can only raise square matrix to power"));
    dims = getAttrib(x, R_DimSymbol);
    nrows = INTEGER(dims)[0];
    ncols = INTEGER(dims)[1];
    if (nrows != ncols)
        error(_("can only raise square matrix to power"));
    if (!isNumeric(y))
        error(_("exponent must be a scalar integer"));
    e = asInteger(y);   /* an NA exponent becomes NA_INTEGER, rejected below */
    if (e < -1)
        error(_("exponent must be >= -1"));
    else if (e == 1)
        return x;
    else if (e == -1) { /* return matrix inverse via solve() */
        SEXP p1, p2, inv;
        /* Build the call solve.default(x) and evaluate it in rho. */
        PROTECT(p1 = p2 = allocList(2));
        SET_TYPEOF(p1, LANGSXP);
        SETCAR(p2, install("solve.default"));
        p2 = CDR(p2);
        SETCAR(p2, x);
        inv = eval(p1, rho);
        UNPROTECT(1);
        return inv;
    }

    /* matrix: running result; x_: current power of x;
       tmp, x__: scratch space for the products. */
    PROTECT(matrix = allocVector(mode, nrows * ncols));
    PROTECT(tmp = allocVector(mode, nrows * ncols));
    PROTECT(x_ = allocVector(mode, nrows * ncols));
    PROTECT(x__ = allocVector(mode, nrows * ncols));

    if (mode == REALSXP)
        Memcpy(REAL(x_), REAL(x), (size_t)nrows*ncols);
    else
        Memcpy(COMPLEX(x_), COMPLEX(x), (size_t)nrows*ncols);

    /* Initialize matrix to the identity matrix: in a square matrix the
       diagonal entries sit at indices that are multiples of ncols+1. */
    if (mode == REALSXP)
        for (i = 0; i < ncols*nrows; i++)
            REAL(matrix)[i] = ((i % (ncols+1) == 0) ? 1 : 0);
    else
        for (i = 0; i < ncols*nrows; i++) {
            COMPLEX(matrix)[i].i = 0.0;
            COMPLEX(matrix)[i].r = ((i % (ncols+1) == 0) ? 1.0 : 0.0);
        }

    /* Square and multiply. For e == 0 the loop body never runs and the
       identity matrix is returned. */
    while (e > 0) {
        if (e & 1) {
            /* Low bit set: matrix = matrix * x_ */
            if (mode == REALSXP)
                matprod(REAL(matrix), nrows, ncols,
                        REAL(x_), nrows, ncols, REAL(tmp));
            else
                cmatprod(COMPLEX(matrix), nrows, ncols,
                         COMPLEX(x_), nrows, ncols, COMPLEX(tmp));
            if (mode == REALSXP)
                Memcpy(REAL(matrix), REAL(tmp), (size_t)nrows*ncols);
            else
                Memcpy(COMPLEX(matrix), COMPLEX(tmp), (size_t)nrows*ncols);
        }
        e >>= 1;
        if (e > 0) {
            /* More bits remain: x_ = x_ * x_ */
            if (mode == REALSXP)
                matprod(REAL(x_), nrows, ncols,
                        REAL(x_), nrows, ncols, REAL(x__));
            else
                cmatprod(COMPLEX(x_), nrows, ncols,
                         COMPLEX(x_), nrows, ncols, COMPLEX(x__));
            if (mode == REALSXP)
                Memcpy(REAL(x_), REAL(x__), (size_t)nrows*ncols);
            else
                Memcpy(COMPLEX(x_), COMPLEX(x__), (size_t)nrows*ncols);
        }
    }

    /* Attach the dimensions so the result is returned as a matrix. */
    PROTECT(dims2 = allocVector(INTSXP, 2));
    INTEGER(dims2)[0] = nrows;
    INTEGER(dims2)[1] = ncols;
    setAttrib(matrix, R_DimSymbol, dims2);
    UNPROTECT(5);
    return matrix;
}
To actually hook this routine up to the %^% operator, I needed a further piece of glue in /src/main/names.c to associate the operator:
{"%^%", do_matexp, 3, 1, 2, {PP_BINARY, PREC_POWER, 0}}
Then to try it out (after recompiling R of course!):
> A <- matrix(c(1:4), nrow=2, byrow=TRUE)
> A %^% 2
     [,1] [,2]
[1,]    7   10
[2,]   15   22
>