Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2Implementation of optimized einsum. 

3 

4""" 

5import itertools 

6import operator 

7 

8from numpy.core.multiarray import c_einsum 

9from numpy.core.numeric import asanyarray, tensordot 

10from numpy.core.overrides import array_function_dispatch 

11 

12__all__ = ['einsum', 'einsum_path'] 

13 

14einsum_symbols = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ' 

15einsum_symbols_set = set(einsum_symbols) 

16 

17 

def _flop_count(idx_contraction, inner, num_terms, size_dictionary):
    """
    Estimate the number of FLOPs required to perform a single contraction.

    Parameters
    ----------
    idx_contraction : iterable
        The indices involved in the contraction
    inner : bool
        Does this contraction require an inner product?
    num_terms : int
        The number of terms in a contraction
    size_dictionary : dict
        The size of each of the indices in idx_contraction

    Returns
    -------
    flop_count : int
        The total number of FLOPS required for the contraction.

    Examples
    --------

    >>> _flop_count('abc', False, 1, {'a': 2, 'b':3, 'c':5})
    30

    >>> _flop_count('abc', True, 2, {'a': 2, 'b':3, 'c':5})
    60

    """
    # One multiply per pairwise combination of terms, plus one extra
    # operation per element when an inner (summation) product is needed.
    op_factor = max(1, num_terms - 1) + (1 if inner else 0)
    return _compute_size_by_dict(idx_contraction, size_dictionary) * op_factor

55 

56def _compute_size_by_dict(indices, idx_dict): 

57 """ 

58 Computes the product of the elements in indices based on the dictionary 

59 idx_dict. 

60 

61 Parameters 

62 ---------- 

63 indices : iterable 

64 Indices to base the product on. 

65 idx_dict : dictionary 

66 Dictionary of index sizes 

67 

68 Returns 

69 ------- 

70 ret : int 

71 The resulting product. 

72 

73 Examples 

74 -------- 

75 >>> _compute_size_by_dict('abbc', {'a': 2, 'b':3, 'c':5}) 

76 90 

77 

78 """ 

79 ret = 1 

80 for i in indices: 

81 ret *= idx_dict[i] 

82 return ret 

83 

84 

85def _find_contraction(positions, input_sets, output_set): 

86 """ 

87 Finds the contraction for a given set of input and output sets. 

88 

89 Parameters 

90 ---------- 

91 positions : iterable 

92 Integer positions of terms used in the contraction. 

93 input_sets : list 

94 List of sets that represent the lhs side of the einsum subscript 

95 output_set : set 

96 Set that represents the rhs side of the overall einsum subscript 

97 

98 Returns 

99 ------- 

100 new_result : set 

101 The indices of the resulting contraction 

102 remaining : list 

103 List of sets that have not been contracted, the new set is appended to 

104 the end of this list 

105 idx_removed : set 

106 Indices removed from the entire contraction 

107 idx_contraction : set 

108 The indices used in the current contraction 

109 

110 Examples 

111 -------- 

112 

113 # A simple dot product test case 

114 >>> pos = (0, 1) 

115 >>> isets = [set('ab'), set('bc')] 

116 >>> oset = set('ac') 

117 >>> _find_contraction(pos, isets, oset) 

118 ({'a', 'c'}, [{'a', 'c'}], {'b'}, {'a', 'b', 'c'}) 

119 

120 # A more complex case with additional terms in the contraction 

121 >>> pos = (0, 2) 

122 >>> isets = [set('abd'), set('ac'), set('bdc')] 

123 >>> oset = set('ac') 

124 >>> _find_contraction(pos, isets, oset) 

125 ({'a', 'c'}, [{'a', 'c'}, {'a', 'c'}], {'b', 'd'}, {'a', 'b', 'c', 'd'}) 

126 """ 

127 

128 idx_contract = set() 

129 idx_remain = output_set.copy() 

130 remaining = [] 

131 for ind, value in enumerate(input_sets): 

132 if ind in positions: 

133 idx_contract |= value 

134 else: 

135 remaining.append(value) 

136 idx_remain |= value 

137 

138 new_result = idx_remain & idx_contract 

139 idx_removed = (idx_contract - new_result) 

140 remaining.append(new_result) 

141 

142 return (new_result, remaining, idx_removed, idx_contract) 

143 

144 

def _optimal_path(input_sets, output_set, idx_dict, memory_limit):
    """
    Computes all possible pair contractions, sieves the results based
    on ``memory_limit`` and returns the lowest cost path. This algorithm
    scales factorial with respect to the elements in the list ``input_sets``.

    Parameters
    ----------
    input_sets : list
        List of sets that represent the lhs side of the einsum subscript
    output_set : set
        Set that represents the rhs side of the overall einsum subscript
    idx_dict : dictionary
        Dictionary of index sizes
    memory_limit : int
        The maximum number of elements in a temporary array

    Returns
    -------
    path : list
        The optimal contraction order within the memory limit constraint.

    Examples
    --------
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set()
    >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
    >>> _optimal_path(isets, oset, idx_sizes, 5000)
    [(0, 2), (0, 1)]
    """
    # Each search state is (total_cost, contraction path, terms remaining).
    states = [(0, [], input_sets)]
    for step in range(len(input_sets) - 1):
        next_states = []

        # Expand every surviving state by every unique pairwise contraction
        for cost, positions, remaining in states:
            for pair in itertools.combinations(range(len(input_sets) - step), 2):

                contraction = _find_contraction(pair, remaining, output_set)
                new_result, new_inputs, idx_removed, idx_contract = contraction

                # Drop candidates whose intermediate exceeds memory_limit
                if _compute_size_by_dict(new_result, idx_dict) > memory_limit:
                    continue

                new_cost = cost + _flop_count(idx_contract, idx_removed,
                                              len(pair), idx_dict)
                next_states.append((new_cost, positions + [pair], new_inputs))

        if next_states:
            states = next_states
        else:
            # Every candidate blew the memory limit: take the cheapest path
            # found so far and finish with one all-remaining-terms einsum.
            best_path = min(states, key=lambda state: state[0])[1]
            return best_path + [tuple(range(len(input_sets) - step))]

    # Defensive fallback: a single einsum contraction over everything.
    if len(states) == 0:
        return [tuple(range(len(input_sets)))]

    return min(states, key=lambda state: state[0])[1]

214 

def _parse_possible_contraction(positions, input_sets, output_set, idx_dict, memory_limit, path_cost, naive_cost):
    """Compute the cost (removed size + flops) and resultant indices for
    performing the contraction specified by ``positions``.

    Parameters
    ----------
    positions : tuple of int
        The locations of the proposed tensors to contract.
    input_sets : list of sets
        The indices found on each tensors.
    output_set : set
        The output indices of the expression.
    idx_dict : dict
        Mapping of each index to its size.
    memory_limit : int
        The total allowed size for an intermediary tensor.
    path_cost : int
        The contraction cost so far.
    naive_cost : int
        The cost of the unoptimized expression.

    Returns
    -------
    cost : (int, int)
        A tuple containing the size of any indices removed, and the flop cost.
    positions : tuple of int
        The locations of the proposed tensors to contract.
    new_input_sets : list of sets
        The resulting new list of indices if this proposed contraction is performed.

    """
    idx_result, new_input_sets, idx_removed, idx_contract = _find_contraction(
        positions, input_sets, output_set)

    # Reject candidates whose intermediate would not fit in memory_limit
    new_size = _compute_size_by_dict(idx_result, idx_dict)
    if new_size > memory_limit:
        return None

    # Net memory freed: size of the inputs consumed minus the intermediate
    # created.  (NB: this used to be just the size of any removed indices,
    # i.e. helpers.compute_size_by_dict(idx_removed, idx_dict).)
    removed_size = sum(_compute_size_by_dict(input_sets[p], idx_dict)
                       for p in positions) - new_size

    cost = _flop_count(idx_contract, idx_removed, len(positions), idx_dict)

    # Sieve out anything that would make the path worse than naive einsum
    if (path_cost + cost) > naive_cost:
        return None

    # Sort key: largest memory freed first, ties broken by flop cost
    return [(-removed_size, cost), positions, new_input_sets]

271 

272 

def _update_other_results(results, best):
    """Update the positions and provisional input_sets of ``results`` based on
    performing the contraction result ``best``. Remove any involving the tensors
    contracted.

    Parameters
    ----------
    results : list
        List of contraction results produced by ``_parse_possible_contraction``.
    best : list
        The best contraction of ``results`` i.e. the one that will be performed.

    Returns
    -------
    mod_results : list
        The list of modified results, updated with outcome of ``best`` contraction.
    """

    best_con = best[1]
    # NOTE(review): the positions produced in _greedy_path come from
    # itertools.combinations or an (i, new_tensor_pos) generator, so bx < by.
    bx, by = best_con
    mod_results = []

    for cost, (x, y), con_sets in results:

        # Ignore results involving tensors just contracted
        if x in best_con or y in best_con:
            continue

        # Update the input_sets.  ``con_sets`` is the provisional term list
        # for candidate (x, y): x and y removed, their result appended last.
        # Remove best's two tensors, shifting the raw positions bx/by down to
        # account for x and y having already been removed; deleting the
        # larger adjusted index (by, since bx < by) first keeps the second
        # index valid.
        del con_sets[by - int(by > x) - int(by > y)]
        del con_sets[bx - int(bx > x) - int(bx > y)]
        # Insert best's new tensor just before the candidate's own result so
        # the candidate's result remains in the final (last) slot.
        con_sets.insert(-1, best[2][-1])

        # Update the position indices: x and y each shift down once for every
        # one of best's tensors that sat below them.
        mod_con = x - int(x > bx) - int(x > by), y - int(y > bx) - int(y > by)
        mod_results.append((cost, mod_con, con_sets))

    return mod_results

311 

def _greedy_path(input_sets, output_set, idx_dict, memory_limit):
    """
    Finds the path by contracting the best pair until the input list is
    exhausted. The best pair is found by minimizing the tuple
    ``(-prod(indices_removed), cost)``. What this amounts to is prioritizing
    matrix multiplication or inner product operations, then Hadamard like
    operations, and finally outer operations. Outer products are limited by
    ``memory_limit``. This algorithm scales cubically with respect to the
    number of elements in the list ``input_sets``.

    Parameters
    ----------
    input_sets : list
        List of sets that represent the lhs side of the einsum subscript
    output_set : set
        Set that represents the rhs side of the overall einsum subscript
    idx_dict : dictionary
        Dictionary of index sizes
    memory_limit : int
        The maximum number of elements in a temporary array

    Returns
    -------
    path : list
        The greedy contraction order within the memory limit constraint.

    Examples
    --------
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set()
    >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
    >>> _greedy_path(isets, oset, idx_sizes, 5000)
    [(0, 2), (0, 1)]
    """

    # Handle trivial cases that leaked through
    if len(input_sets) == 1:
        return [(0,)]
    elif len(input_sets) == 2:
        return [(0, 1)]

    # Build up a naive cost; it acts as an upper bound used by
    # _parse_possible_contraction to sieve out too-expensive candidates.
    contract = _find_contraction(range(len(input_sets)), input_sets, output_set)
    idx_result, new_input_sets, idx_removed, idx_contract = contract
    naive_cost = _flop_count(idx_contract, idx_removed, len(input_sets), idx_dict)

    # Initially iterate over all pairs
    comb_iter = itertools.combinations(range(len(input_sets)), 2)
    known_contractions = []

    path_cost = 0
    path = []

    for iteration in range(len(input_sets) - 1):

        # Iterate over all pairs on first step, only previously found pairs on subsequent steps
        for positions in comb_iter:

            # Always initially ignore outer products (disjoint index sets)
            if input_sets[positions[0]].isdisjoint(input_sets[positions[1]]):
                continue

            result = _parse_possible_contraction(positions, input_sets, output_set, idx_dict, memory_limit, path_cost,
                                                 naive_cost)
            if result is not None:
                known_contractions.append(result)

        # If we do not have an inner contraction, rescan pairs including outer products
        if len(known_contractions) == 0:

            # Then check the outer products
            for positions in itertools.combinations(range(len(input_sets)), 2):
                result = _parse_possible_contraction(positions, input_sets, output_set, idx_dict, memory_limit,
                                                     path_cost, naive_cost)
                if result is not None:
                    known_contractions.append(result)

            # If we still did not find any remaining contractions, default back to einsum like behavior
            if len(known_contractions) == 0:
                path.append(tuple(range(len(input_sets))))
                break

        # Sort based on first index, i.e. the (-removed_size, cost) tuple
        best = min(known_contractions, key=lambda x: x[0])

        # Now propagate as many unused contractions as possible to next iteration
        known_contractions = _update_other_results(known_contractions, best)

        # Next iteration only compute contractions with the new tensor
        # (appended at the end of best[2]); all other pairs are accounted for.
        input_sets = best[2]
        new_tensor_pos = len(input_sets) - 1
        comb_iter = ((i, new_tensor_pos) for i in range(new_tensor_pos))

        # Update path and total cost
        path.append(best[1])
        path_cost += best[0][1]

    return path

411 

412 

413def _can_dot(inputs, result, idx_removed): 

414 """ 

415 Checks if we can use BLAS (np.tensordot) call and its beneficial to do so. 

416 

417 Parameters 

418 ---------- 

419 inputs : list of str 

420 Specifies the subscripts for summation. 

421 result : str 

422 Resulting summation. 

423 idx_removed : set 

424 Indices that are removed in the summation 

425 

426 

427 Returns 

428 ------- 

429 type : bool 

430 Returns true if BLAS should and can be used, else False 

431 

432 Notes 

433 ----- 

434 If the operations is BLAS level 1 or 2 and is not already aligned 

435 we default back to einsum as the memory movement to copy is more 

436 costly than the operation itself. 

437 

438 

439 Examples 

440 -------- 

441 

442 # Standard GEMM operation 

443 >>> _can_dot(['ij', 'jk'], 'ik', set('j')) 

444 True 

445 

446 # Can use the standard BLAS, but requires odd data movement 

447 >>> _can_dot(['ijj', 'jk'], 'ik', set('j')) 

448 False 

449 

450 # DDOT where the memory is not aligned 

451 >>> _can_dot(['ijk', 'ikj'], '', set('ijk')) 

452 False 

453 

454 """ 

455 

456 # All `dot` calls remove indices 

457 if len(idx_removed) == 0: 

458 return False 

459 

460 # BLAS can only handle two operands 

461 if len(inputs) != 2: 

462 return False 

463 

464 input_left, input_right = inputs 

465 

466 for c in set(input_left + input_right): 

467 # can't deal with repeated indices on same input or more than 2 total 

468 nl, nr = input_left.count(c), input_right.count(c) 

469 if (nl > 1) or (nr > 1) or (nl + nr > 2): 

470 return False 

471 

472 # can't do implicit summation or dimension collapse e.g. 

473 # "ab,bc->c" (implicitly sum over 'a') 

474 # "ab,ca->ca" (take diagonal of 'a') 

475 if nl + nr - 1 == int(c in result): 

476 return False 

477 

478 # Build a few temporaries 

479 set_left = set(input_left) 

480 set_right = set(input_right) 

481 keep_left = set_left - idx_removed 

482 keep_right = set_right - idx_removed 

483 rs = len(idx_removed) 

484 

485 # At this point we are a DOT, GEMV, or GEMM operation 

486 

487 # Handle inner products 

488 

489 # DDOT with aligned data 

490 if input_left == input_right: 

491 return True 

492 

493 # DDOT without aligned data (better to use einsum) 

494 if set_left == set_right: 

495 return False 

496 

497 # Handle the 4 possible (aligned) GEMV or GEMM cases 

498 

499 # GEMM or GEMV no transpose 

500 if input_left[-rs:] == input_right[:rs]: 

501 return True 

502 

503 # GEMM or GEMV transpose both 

504 if input_left[:rs] == input_right[-rs:]: 

505 return True 

506 

507 # GEMM or GEMV transpose right 

508 if input_left[-rs:] == input_right[-rs:]: 

509 return True 

510 

511 # GEMM or GEMV transpose left 

512 if input_left[:rs] == input_right[:rs]: 

513 return True 

514 

515 # Einsum is faster than GEMV if we have to copy data 

516 if not keep_left or not keep_right: 

517 return False 

518 

519 # We are a matrix-matrix product, but we need to copy data 

520 return True 

521 

522 

def _parse_einsum_input(operands):
    """
    A reproduction of einsum c side einsum parsing in python.

    Returns
    -------
    input_strings : str
        Parsed input strings
    output_string : str
        Parsed output string
    operands : list of array_like
        The operands to use in the numpy contraction

    Examples
    --------
    The operand list is simplified to reduce printing:

    >>> np.random.seed(123)
    >>> a = np.random.rand(4, 4)
    >>> b = np.random.rand(4, 4, 4)
    >>> _parse_einsum_input(('...a,...a->...', a, b))
    ('za,xza', 'xz', [a, b]) # may vary

    >>> _parse_einsum_input((a, [Ellipsis, 0], b, [Ellipsis, 0]))
    ('za,xza', 'xz', [a, b]) # may vary
    """

    if len(operands) == 0:
        raise ValueError("No input operands")

    if isinstance(operands[0], str):
        # String form: ("ij,jk->ik", a, b)
        subscripts = operands[0].replace(" ", "")
        operands = [asanyarray(v) for v in operands[1:]]

        # Ensure all characters are valid
        for s in subscripts:
            if s in '.,->':
                continue
            if s not in einsum_symbols:
                raise ValueError("Character %s is not a valid symbol." % s)

    else:
        # Interleaved form: (op0, sublist0, op1, sublist1, ..., [sublistout]);
        # rebuild an equivalent subscript string from the integer labels.
        tmp_operands = list(operands)
        operand_list = []
        subscript_list = []
        for p in range(len(operands) // 2):
            operand_list.append(tmp_operands.pop(0))
            subscript_list.append(tmp_operands.pop(0))

        # A trailing odd element, if any, is the output sublist
        output_list = tmp_operands[-1] if len(tmp_operands) else None
        operands = [asanyarray(v) for v in operand_list]
        subscripts = ""
        last = len(subscript_list) - 1
        for num, sub in enumerate(subscript_list):
            for s in sub:
                if s is Ellipsis:
                    subscripts += "..."
                else:
                    try:
                        s = operator.index(s)
                    except TypeError as e:
                        raise TypeError("For this input type lists must contain "
                                        "either int or Ellipsis") from e
                    subscripts += einsum_symbols[s]
            if num != last:
                subscripts += ","

        if output_list is not None:
            subscripts += "->"
            for s in output_list:
                if s is Ellipsis:
                    subscripts += "..."
                else:
                    try:
                        s = operator.index(s)
                    except TypeError as e:
                        raise TypeError("For this input type lists must contain "
                                        "either int or Ellipsis") from e
                    subscripts += einsum_symbols[s]
    # Check for proper "->"
    if ("-" in subscripts) or (">" in subscripts):
        invalid = (subscripts.count("-") > 1) or (subscripts.count(">") > 1)
        if invalid or (subscripts.count("->") != 1):
            raise ValueError("Subscripts can only contain one '->'.")

    # Parse ellipses: replace each "..." with concrete symbols drawn from the
    # pool of symbols not otherwise used in the expression.
    if "." in subscripts:
        used = subscripts.replace(".", "").replace(",", "").replace("->", "")
        unused = list(einsum_symbols_set - set(used))
        ellipse_inds = "".join(unused)
        longest = 0

        if "->" in subscripts:
            input_tmp, output_sub = subscripts.split("->")
            split_subscripts = input_tmp.split(",")
            out_sub = True
        else:
            split_subscripts = subscripts.split(',')
            out_sub = False

        for num, sub in enumerate(split_subscripts):
            if "." in sub:
                if (sub.count(".") != 3) or (sub.count("...") != 1):
                    raise ValueError("Invalid Ellipses.")

                # Take into account numerical values
                if operands[num].shape == ():
                    ellipse_count = 0
                else:
                    # The ellipsis covers whatever dims the explicit labels
                    # in this term do not.
                    ellipse_count = max(operands[num].ndim, 1)
                    ellipse_count -= (len(sub) - 3)

                if ellipse_count > longest:
                    longest = ellipse_count

                if ellipse_count < 0:
                    raise ValueError("Ellipses lengths do not match.")
                elif ellipse_count == 0:
                    split_subscripts[num] = sub.replace('...', '')
                else:
                    # Use the tail of the pool so all terms broadcast-align
                    # on the same trailing symbols.
                    rep_inds = ellipse_inds[-ellipse_count:]
                    split_subscripts[num] = sub.replace('...', rep_inds)

        subscripts = ",".join(split_subscripts)
        if longest == 0:
            out_ellipse = ""
        else:
            out_ellipse = ellipse_inds[-longest:]

        if out_sub:
            subscripts += "->" + output_sub.replace("...", out_ellipse)
        else:
            # Special care for outputless ellipses: implicit output keeps the
            # broadcast dims plus every index that appears exactly once.
            output_subscript = ""
            tmp_subscripts = subscripts.replace(",", "")
            for s in sorted(set(tmp_subscripts)):
                if s not in (einsum_symbols):
                    raise ValueError("Character %s is not a valid symbol." % s)
                if tmp_subscripts.count(s) == 1:
                    output_subscript += s
            normal_inds = ''.join(sorted(set(output_subscript) -
                                         set(out_ellipse)))

            subscripts += "->" + out_ellipse + normal_inds

    # Build output string if does not exist
    if "->" in subscripts:
        input_subscripts, output_subscript = subscripts.split("->")
    else:
        input_subscripts = subscripts
        # Build output subscripts: indices appearing exactly once, sorted
        tmp_subscripts = subscripts.replace(",", "")
        output_subscript = ""
        for s in sorted(set(tmp_subscripts)):
            if s not in einsum_symbols:
                raise ValueError("Character %s is not a valid symbol." % s)
            if tmp_subscripts.count(s) == 1:
                output_subscript += s

    # Make sure output subscripts are in the input
    for char in output_subscript:
        if char not in input_subscripts:
            raise ValueError("Output character %s did not appear in the input"
                             % char)

    # Make sure number operands is equivalent to the number of terms
    if len(input_subscripts.split(',')) != len(operands):
        raise ValueError("Number of einsum subscripts must be equal to the "
                         "number of operands.")

    return (input_subscripts, output_subscript, operands)

694 

695 

def _einsum_path_dispatcher(*operands, optimize=None, einsum_call=None):
    """Dispatcher for `einsum_path`: returns the arguments to examine for
    ``__array_function__`` overrides."""
    # NOTE: technically, we should only dispatch on array-like arguments, not
    # subscripts (given as strings). But separating operands into
    # arrays/subscripts is a little tricky/slow (given einsum's two supported
    # signatures), so as a practical shortcut we dispatch on everything.
    # Strings will be ignored for dispatching since they don't define
    # __array_function__.
    return operands

704 

705 

@array_function_dispatch(_einsum_path_dispatcher, module='numpy')
def einsum_path(*operands, optimize='greedy', einsum_call=False):
    """
    einsum_path(subscripts, *operands, optimize='greedy')

    Evaluates the lowest cost contraction order for an einsum expression by
    considering the creation of intermediate arrays.

    Parameters
    ----------
    subscripts : str
        Specifies the subscripts for summation.
    *operands : list of array_like
        These are the arrays for the operation.
    optimize : {bool, list, tuple, 'greedy', 'optimal'}
        Choose the type of path. If a tuple is provided, the second argument is
        assumed to be the maximum intermediate size created. If only a single
        argument is provided the largest input or output array size is used
        as a maximum intermediate size.

        * if a list is given that starts with ``einsum_path``, uses this as the
          contraction path
        * if False no optimization is taken
        * if True defaults to the 'greedy' algorithm
        * 'optimal' An algorithm that combinatorially explores all possible
          ways of contracting the listed tensors and chooses the least costly
          path. Scales exponentially with the number of terms in the
          contraction.
        * 'greedy' An algorithm that chooses the best pair contraction
          at each step. Effectively, this algorithm searches the largest inner,
          Hadamard, and then outer products at each step. Scales cubically with
          the number of terms in the contraction. Equivalent to the 'optimal'
          path for most contractions.

        Default is 'greedy'.

    Returns
    -------
    path : list of tuples
        A list representation of the einsum path.
    string_repr : str
        A printable representation of the einsum path.

    Notes
    -----
    The resulting path indicates which terms of the input contraction should be
    contracted first, the result of this contraction is then appended to the
    end of the contraction list. This list can then be iterated over until all
    intermediate contractions are complete.

    See Also
    --------
    einsum, linalg.multi_dot

    Examples
    --------

    We can begin with a chain dot example. In this case, it is optimal to
    contract the ``b`` and ``c`` tensors first as represented by the first
    element of the path ``(1, 2)``. The resulting tensor is added to the end
    of the contraction and the remaining contraction ``(0, 1)`` is then
    completed.

    >>> np.random.seed(123)
    >>> a = np.random.rand(2, 2)
    >>> b = np.random.rand(2, 5)
    >>> c = np.random.rand(5, 2)
    >>> path_info = np.einsum_path('ij,jk,kl->il', a, b, c, optimize='greedy')
    >>> print(path_info[0])
    ['einsum_path', (1, 2), (0, 1)]
    >>> print(path_info[1])
      Complete contraction:  ij,jk,kl->il # may vary
             Naive scaling:  4
         Optimized scaling:  3
          Naive FLOP count:  1.600e+02
      Optimized FLOP count:  5.600e+01
       Theoretical speedup:  2.857
      Largest intermediate:  4.000e+00 elements
    -------------------------------------------------------------------------
    scaling                  current                                remaining
    -------------------------------------------------------------------------
       3                   kl,jk->jl                                ij,jl->il
       3                   jl,ij->il                                   il->il


    A more complex index transformation example.

    >>> I = np.random.rand(10, 10, 10, 10)
    >>> C = np.random.rand(10, 10)
    >>> path_info = np.einsum_path('ea,fb,abcd,gc,hd->efgh', C, C, I, C, C,
    ...                            optimize='greedy')

    >>> print(path_info[0])
    ['einsum_path', (0, 2), (0, 3), (0, 2), (0, 1)]
    >>> print(path_info[1])
      Complete contraction:  ea,fb,abcd,gc,hd->efgh # may vary
             Naive scaling:  8
         Optimized scaling:  5
          Naive FLOP count:  8.000e+08
      Optimized FLOP count:  8.000e+05
       Theoretical speedup:  1000.000
      Largest intermediate:  1.000e+04 elements
    --------------------------------------------------------------------------
    scaling        current                                remaining
    --------------------------------------------------------------------------
       5           abcd,ea->bcde                    fb,gc,hd,bcde->efgh
       5           bcde,fb->cdef                       gc,hd,cdef->efgh
       5           cdef,gc->defg                          hd,defg->efgh
       5           defg,hd->efgh                             efgh->efgh
    """

    # Figure out what the path really is
    path_type = optimize
    if path_type is True:
        path_type = 'greedy'
    if path_type is None:
        path_type = False

    memory_limit = None

    # No optimization or a named path algorithm
    if (path_type is False) or isinstance(path_type, str):
        pass

    # Given an explicit path
    elif len(path_type) and (path_type[0] == 'einsum_path'):
        pass

    # Path tuple with memory limit
    elif ((len(path_type) == 2) and isinstance(path_type[0], str) and
            isinstance(path_type[1], (int, float))):
        memory_limit = int(path_type[1])
        path_type = path_type[0]

    else:
        raise TypeError("Did not understand the path: %s" % str(path_type))

    # Hidden option, only einsum should call this
    einsum_call_arg = einsum_call

    # Python side parsing
    input_subscripts, output_subscript, operands = _parse_einsum_input(operands)

    # Build a few useful list and sets
    input_list = input_subscripts.split(',')
    input_sets = [set(x) for x in input_list]
    output_set = set(output_subscript)
    indices = set(input_subscripts.replace(',', ''))

    # Get length of each unique dimension and ensure all dimensions are correct
    dimension_dict = {}
    broadcast_indices = [[] for x in range(len(input_list))]
    for tnum, term in enumerate(input_list):
        sh = operands[tnum].shape
        if len(sh) != len(term):
            # BUG FIX: report this operand's own subscript term.  Previously
            # this formatted ``input_subscripts[tnum]``, which indexes a
            # single *character* of the full subscript string.
            raise ValueError("Einstein sum subscript %s does not contain the "
                             "correct number of indices for operand %d."
                             % (term, tnum))
        for cnum, char in enumerate(term):
            dim = sh[cnum]

            # Build out broadcast indices
            if dim == 1:
                broadcast_indices[tnum].append(char)

            if char in dimension_dict.keys():
                # For broadcasting cases we always want the largest dim size
                if dimension_dict[char] == 1:
                    dimension_dict[char] = dim
                elif dim not in (1, dimension_dict[char]):
                    raise ValueError("Size of label '%s' for operand %d (%d) "
                                     "does not match previous terms (%d)."
                                     % (char, tnum, dimension_dict[char], dim))
            else:
                dimension_dict[char] = dim

    # Convert broadcast inds to sets
    broadcast_indices = [set(x) for x in broadcast_indices]

    # Compute size of each input array plus the output array
    size_list = [_compute_size_by_dict(term, dimension_dict)
                 for term in input_list + [output_subscript]]
    max_size = max(size_list)

    # Default the memory limit to the largest input or output array
    if memory_limit is None:
        memory_arg = max_size
    else:
        memory_arg = memory_limit

    # Compute naive cost
    # This isn't quite right, need to look into exactly how einsum does this
    inner_product = (sum(len(x) for x in input_sets) - len(indices)) > 0
    naive_cost = _flop_count(indices, inner_product, len(input_list), dimension_dict)

    # Compute the path
    if (path_type is False) or (len(input_list) in [1, 2]) or (indices == output_set):
        # Nothing to be optimized, leave it to einsum
        path = [tuple(range(len(input_list)))]
    elif path_type == "greedy":
        path = _greedy_path(input_sets, output_set, dimension_dict, memory_arg)
    elif path_type == "optimal":
        path = _optimal_path(input_sets, output_set, dimension_dict, memory_arg)
    elif path_type[0] == 'einsum_path':
        path = path_type[1:]
    else:
        # BUG FIX: the message was previously passed as two positional
        # KeyError arguments ("...%s...", path_type) instead of %-formatted.
        raise KeyError("Path name %s not found" % path_type)

    cost_list, scale_list, size_list, contraction_list = [], [], [], []

    # Build contraction tuple (positions, gemm, einsum_str, remaining)
    for cnum, contract_inds in enumerate(path):
        # Make sure we remove inds from right to left
        contract_inds = tuple(sorted(list(contract_inds), reverse=True))

        contract = _find_contraction(contract_inds, input_sets, output_set)
        out_inds, input_sets, idx_removed, idx_contract = contract

        cost = _flop_count(idx_contract, idx_removed, len(contract_inds), dimension_dict)
        cost_list.append(cost)
        scale_list.append(len(idx_contract))
        size_list.append(_compute_size_by_dict(out_inds, dimension_dict))

        bcast = set()
        tmp_inputs = []
        for x in contract_inds:
            tmp_inputs.append(input_list.pop(x))
            bcast |= broadcast_indices.pop(x)

        new_bcast_inds = bcast - idx_removed

        # If we're broadcasting, nix blas
        if not len(idx_removed & bcast):
            do_blas = _can_dot(tmp_inputs, out_inds, idx_removed)
        else:
            do_blas = False

        # Last contraction
        if (cnum - len(path)) == -1:
            idx_result = output_subscript
        else:
            # Order intermediate indices by size for cache friendliness
            sort_result = [(dimension_dict[ind], ind) for ind in out_inds]
            idx_result = "".join([x[1] for x in sorted(sort_result)])

        input_list.append(idx_result)
        broadcast_indices.append(new_bcast_inds)
        einsum_str = ",".join(tmp_inputs) + "->" + idx_result

        contraction = (contract_inds, idx_removed, einsum_str, input_list[:], do_blas)
        contraction_list.append(contraction)

    opt_cost = sum(cost_list) + 1

    if einsum_call_arg:
        return (operands, contraction_list)

    # Return the path along with a nice string representation
    overall_contraction = input_subscripts + "->" + output_subscript
    header = ("scaling", "current", "remaining")

    speedup = naive_cost / opt_cost
    max_i = max(size_list)

    path_print = "  Complete contraction:  %s\n" % overall_contraction
    path_print += "         Naive scaling:  %d\n" % len(indices)
    path_print += "     Optimized scaling:  %d\n" % max(scale_list)
    path_print += "      Naive FLOP count:  %.3e\n" % naive_cost
    path_print += "  Optimized FLOP count:  %.3e\n" % opt_cost
    path_print += "   Theoretical speedup:  %3.3f\n" % speedup
    path_print += "  Largest intermediate:  %.3e elements\n" % max_i
    path_print += "-" * 74 + "\n"
    path_print += "%6s %24s %40s\n" % header
    path_print += "-" * 74

    for n, contraction in enumerate(contraction_list):
        inds, idx_rm, einsum_str, remaining, blas = contraction
        remaining_str = ",".join(remaining) + "->" + output_subscript
        path_run = (scale_list[n], einsum_str, remaining_str)
        path_print += "\n%4d    %24s %40s" % path_run

    path = ['einsum_path'] + path
    return (path, path_print)

987 

988 

989def _einsum_dispatcher(*operands, out=None, optimize=None, **kwargs): 

990 # Arguably we dispatch on more arguments that we really should; see note in 

991 # _einsum_path_dispatcher for why. 

992 yield from operands 

993 yield out 

994 

995 

996# Rewrite einsum to handle different cases 

@array_function_dispatch(_einsum_dispatcher, module='numpy')
def einsum(*operands, out=None, optimize=False, **kwargs):
    """
    einsum(subscripts, *operands, out=None, dtype=None, order='K',
           casting='safe', optimize=False)

    Evaluates the Einstein summation convention on the operands.

    Using the Einstein summation convention, many common multi-dimensional,
    linear algebraic array operations can be represented in a simple fashion.
    In *implicit* mode `einsum` computes these values.

    In *explicit* mode, `einsum` provides further flexibility to compute
    other array operations that might not be considered classical Einstein
    summation operations, by disabling, or forcing summation over specified
    subscript labels.

    See the notes and examples for clarification.

    Parameters
    ----------
    subscripts : str
        Specifies the subscripts for summation as comma separated list of
        subscript labels. An implicit (classical Einstein summation)
        calculation is performed unless the explicit indicator '->' is
        included as well as subscript labels of the precise output form.
    operands : list of array_like
        These are the arrays for the operation.
    out : ndarray, optional
        If provided, the calculation is done into this array.
    dtype : {data-type, None}, optional
        If provided, forces the calculation to use the data type specified.
        Note that you may have to also give a more liberal `casting`
        parameter to allow the conversions. Default is None.
    order : {'C', 'F', 'A', 'K'}, optional
        Controls the memory layout of the output. 'C' means it should
        be C contiguous. 'F' means it should be Fortran contiguous,
        'A' means it should be 'F' if the inputs are all 'F', 'C' otherwise.
        'K' means it should be as close to the layout as the inputs as
        is possible, including arbitrarily permuted axes.
        Default is 'K'.
    casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
        Controls what kind of data casting may occur.  Setting this to
        'unsafe' is not recommended, as it can adversely affect accumulations.

          * 'no' means the data types should not be cast at all.
          * 'equiv' means only byte-order changes are allowed.
          * 'safe' means only casts which can preserve values are allowed.
          * 'same_kind' means only safe casts or casts within a kind,
            like float64 to float32, are allowed.
          * 'unsafe' means any data conversions may be done.

        Default is 'safe'.
    optimize : {False, True, 'greedy', 'optimal'}, optional
        Controls if intermediate optimization should occur. No optimization
        will occur if False and True will default to the 'greedy' algorithm.
        Also accepts an explicit contraction list from the ``np.einsum_path``
        function. See ``np.einsum_path`` for more details. Defaults to False.

    Returns
    -------
    output : ndarray
        The calculation based on the Einstein summation convention.

    See Also
    --------
    einsum_path, dot, inner, outer, tensordot, linalg.multi_dot

    Notes
    -----
    .. versionadded:: 1.6.0

    The Einstein summation convention can be used to compute
    many multi-dimensional, linear algebraic array operations. `einsum`
    provides a succinct way of representing these.

    A non-exhaustive list of these operations,
    which can be computed by `einsum`, is shown below along with examples:

    * Trace of an array, :py:func:`numpy.trace`.
    * Return a diagonal, :py:func:`numpy.diag`.
    * Array axis summations, :py:func:`numpy.sum`.
    * Transpositions and permutations, :py:func:`numpy.transpose`.
    * Matrix multiplication and dot product, :py:func:`numpy.matmul` :py:func:`numpy.dot`.
    * Vector inner and outer products, :py:func:`numpy.inner` :py:func:`numpy.outer`.
    * Broadcasting, element-wise and scalar multiplication, :py:func:`numpy.multiply`.
    * Tensor contractions, :py:func:`numpy.tensordot`.
    * Chained array operations, in efficient calculation order, :py:func:`numpy.einsum_path`.

    The subscripts string is a comma-separated list of subscript labels,
    where each label refers to a dimension of the corresponding operand.
    Whenever a label is repeated it is summed, so ``np.einsum('i,i', a, b)``
    is equivalent to :py:func:`np.inner(a,b) <numpy.inner>`. If a label
    appears only once, it is not summed, so ``np.einsum('i', a)`` produces a
    view of ``a`` with no changes. A further example ``np.einsum('ij,jk', a, b)``
    describes traditional matrix multiplication and is equivalent to
    :py:func:`np.matmul(a,b) <numpy.matmul>`. Repeated subscript labels in one
    operand take the diagonal. For example, ``np.einsum('ii', a)`` is equivalent
    to :py:func:`np.trace(a) <numpy.trace>`.

    In *implicit mode*, the chosen subscripts are important
    since the axes of the output are reordered alphabetically.  This
    means that ``np.einsum('ij', a)`` doesn't affect a 2D array, while
    ``np.einsum('ji', a)`` takes its transpose. Additionally,
    ``np.einsum('ij,jk', a, b)`` returns a matrix multiplication, while,
    ``np.einsum('ij,jh', a, b)`` returns the transpose of the
    multiplication since subscript 'h' precedes subscript 'i'.

    In *explicit mode* the output can be directly controlled by
    specifying output subscript labels.  This requires the
    identifier '->' as well as the list of output subscript labels.
    This feature increases the flexibility of the function since
    summing can be disabled or forced when required. The call
    ``np.einsum('i->', a)`` is like :py:func:`np.sum(a, axis=-1) <numpy.sum>`,
    and ``np.einsum('ii->i', a)`` is like :py:func:`np.diag(a) <numpy.diag>`.
    The difference is that `einsum` does not allow broadcasting by default.
    Additionally ``np.einsum('ij,jh->ih', a, b)`` directly specifies the
    order of the output subscript labels and therefore returns matrix
    multiplication, unlike the example above in implicit mode.

    To enable and control broadcasting, use an ellipsis.  Default
    NumPy-style broadcasting is done by adding an ellipsis
    to the left of each term, like ``np.einsum('...ii->...i', a)``.
    To take the trace along the first and last axes,
    you can do ``np.einsum('i...i', a)``, or to do a matrix-matrix
    product with the left-most indices instead of rightmost, one can do
    ``np.einsum('ij...,jk...->ik...', a, b)``.

    When there is only one operand, no axes are summed, and no output
    parameter is provided, a view into the operand is returned instead
    of a new array.  Thus, taking the diagonal as ``np.einsum('ii->i', a)``
    produces a view (changed in version 1.10.0).

    `einsum` also provides an alternative way to provide the subscripts
    and operands as ``einsum(op0, sublist0, op1, sublist1, ..., [sublistout])``.
    If the output shape is not provided in this format `einsum` will be
    calculated in implicit mode, otherwise it will be performed explicitly.
    The examples below have corresponding `einsum` calls with the two
    parameter methods.

    .. versionadded:: 1.10.0

    Views returned from einsum are now writeable whenever the input array
    is writeable. For example, ``np.einsum('ijk...->kji...', a)`` will now
    have the same effect as :py:func:`np.swapaxes(a, 0, 2) <numpy.swapaxes>`
    and ``np.einsum('ii->i', a)`` will return a writeable view of the diagonal
    of a 2D array.

    .. versionadded:: 1.12.0

    Added the ``optimize`` argument which will optimize the contraction order
    of an einsum expression. For a contraction with three or more operands this
    can greatly increase the computational efficiency at the cost of a larger
    memory footprint during computation.

    Typically a 'greedy' algorithm is applied which empirical tests have shown
    returns the optimal path in the majority of cases. In some cases 'optimal'
    will return the superlative path through a more expensive, exhaustive search.
    For iterative calculations it may be advisable to calculate the optimal path
    once and reuse that path by supplying it as an argument. An example is given
    below.

    See :py:func:`numpy.einsum_path` for more details.

    Examples
    --------
    >>> a = np.arange(25).reshape(5,5)
    >>> b = np.arange(5)
    >>> c = np.arange(6).reshape(2,3)

    Trace of a matrix:

    >>> np.einsum('ii', a)
    60
    >>> np.einsum(a, [0,0])
    60
    >>> np.trace(a)
    60

    Extract the diagonal (requires explicit form):

    >>> np.einsum('ii->i', a)
    array([ 0,  6, 12, 18, 24])
    >>> np.einsum(a, [0,0], [0])
    array([ 0,  6, 12, 18, 24])
    >>> np.diag(a)
    array([ 0,  6, 12, 18, 24])

    Sum over an axis (requires explicit form):

    >>> np.einsum('ij->i', a)
    array([ 10,  35,  60,  85, 110])
    >>> np.einsum(a, [0,1], [0])
    array([ 10,  35,  60,  85, 110])
    >>> np.sum(a, axis=1)
    array([ 10,  35,  60,  85, 110])

    For higher dimensional arrays summing a single axis can be done with ellipsis:

    >>> np.einsum('...j->...', a)
    array([ 10,  35,  60,  85, 110])
    >>> np.einsum(a, [Ellipsis,1], [Ellipsis])
    array([ 10,  35,  60,  85, 110])

    Compute a matrix transpose, or reorder any number of axes:

    >>> np.einsum('ji', c)
    array([[0, 3],
           [1, 4],
           [2, 5]])
    >>> np.einsum('ij->ji', c)
    array([[0, 3],
           [1, 4],
           [2, 5]])
    >>> np.einsum(c, [1,0])
    array([[0, 3],
           [1, 4],
           [2, 5]])
    >>> np.transpose(c)
    array([[0, 3],
           [1, 4],
           [2, 5]])

    Vector inner products:

    >>> np.einsum('i,i', b, b)
    30
    >>> np.einsum(b, [0], b, [0])
    30
    >>> np.inner(b,b)
    30

    Matrix vector multiplication:

    >>> np.einsum('ij,j', a, b)
    array([ 30,  80, 130, 180, 230])
    >>> np.einsum(a, [0,1], b, [1])
    array([ 30,  80, 130, 180, 230])
    >>> np.dot(a, b)
    array([ 30,  80, 130, 180, 230])
    >>> np.einsum('...j,j', a, b)
    array([ 30,  80, 130, 180, 230])

    Broadcasting and scalar multiplication:

    >>> np.einsum('..., ...', 3, c)
    array([[ 0,  3,  6],
           [ 9, 12, 15]])
    >>> np.einsum(',ij', 3, c)
    array([[ 0,  3,  6],
           [ 9, 12, 15]])
    >>> np.einsum(3, [Ellipsis], c, [Ellipsis])
    array([[ 0,  3,  6],
           [ 9, 12, 15]])
    >>> np.multiply(3, c)
    array([[ 0,  3,  6],
           [ 9, 12, 15]])

    Vector outer product:

    >>> np.einsum('i,j', np.arange(2)+1, b)
    array([[0, 1, 2, 3, 4],
           [0, 2, 4, 6, 8]])
    >>> np.einsum(np.arange(2)+1, [0], b, [1])
    array([[0, 1, 2, 3, 4],
           [0, 2, 4, 6, 8]])
    >>> np.outer(np.arange(2)+1, b)
    array([[0, 1, 2, 3, 4],
           [0, 2, 4, 6, 8]])

    Tensor contraction:

    >>> a = np.arange(60.).reshape(3,4,5)
    >>> b = np.arange(24.).reshape(4,3,2)
    >>> np.einsum('ijk,jil->kl', a, b)
    array([[4400., 4730.],
           [4532., 4874.],
           [4664., 5018.],
           [4796., 5162.],
           [4928., 5306.]])
    >>> np.einsum(a, [0,1,2], b, [1,0,3], [2,3])
    array([[4400., 4730.],
           [4532., 4874.],
           [4664., 5018.],
           [4796., 5162.],
           [4928., 5306.]])
    >>> np.tensordot(a,b, axes=([1,0],[0,1]))
    array([[4400., 4730.],
           [4532., 4874.],
           [4664., 5018.],
           [4796., 5162.],
           [4928., 5306.]])

    Writeable returned arrays (since version 1.10.0):

    >>> a = np.zeros((3, 3))
    >>> np.einsum('ii->i', a)[:] = 1
    >>> a
    array([[1., 0., 0.],
           [0., 1., 0.],
           [0., 0., 1.]])

    Example of ellipsis use:

    >>> a = np.arange(6).reshape((3,2))
    >>> b = np.arange(12).reshape((4,3))
    >>> np.einsum('ki,jk->ij', a, b)
    array([[10, 28, 46, 64],
           [13, 40, 67, 94]])
    >>> np.einsum('ki,...k->i...', a, b)
    array([[10, 28, 46, 64],
           [13, 40, 67, 94]])
    >>> np.einsum('k...,jk', a, b)
    array([[10, 28, 46, 64],
           [13, 40, 67, 94]])

    Chained array operations. For more complicated contractions, speed ups
    might be achieved by repeatedly computing a 'greedy' path or pre-computing the
    'optimal' path and repeatedly applying it, using an
    `einsum_path` insertion (since version 1.12.0). Performance improvements can be
    particularly significant with larger arrays:

    >>> a = np.ones(64).reshape(2,4,8)

    Basic `einsum`: ~1520ms  (benchmarked on 3.1GHz Intel i5.)

    >>> for iteration in range(500):
    ...     _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a)

    Sub-optimal `einsum` (due to repeated path calculation time): ~330ms

    >>> for iteration in range(500):
    ...     _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='optimal')

    Greedy `einsum` (faster optimal path approximation): ~160ms

    >>> for iteration in range(500):
    ...     _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='greedy')

    Optimal `einsum` (best usage pattern in some use cases): ~110ms

    >>> path = np.einsum_path('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='optimal')[0]
    >>> for iteration in range(500):
    ...     _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize=path)

    """
    # Special handling if out is specified
    specified_out = out is not None

    # If no optimization, run pure einsum
    if optimize is False:
        if specified_out:
            kwargs['out'] = out
        return c_einsum(*operands, **kwargs)

    # Check the kwargs to avoid a more cryptic error later, without having to
    # repeat default values here
    valid_einsum_kwargs = ['dtype', 'order', 'casting']
    unknown_kwargs = [k for k in kwargs if k not in valid_einsum_kwargs]
    if unknown_kwargs:
        raise TypeError("Did not understand the following kwargs: %s"
                        % unknown_kwargs)

    # Build the contraction list and operand
    operands, contraction_list = einsum_path(*operands, optimize=optimize,
                                             einsum_call=True)

    # Start contraction loop
    for num, contraction in enumerate(contraction_list):
        inds, idx_rm, einsum_str, remaining, blas = contraction
        # einsum_path records positions sorted high-to-low, so popping in
        # this order keeps the remaining indices valid.
        tmp_operands = [operands.pop(x) for x in inds]

        # Do we need to deal with the output?
        handle_out = specified_out and ((num + 1) == len(contraction_list))

        # Call tensordot if still possible
        if blas:
            # Checks have already been handled
            input_str, results_index = einsum_str.split('->')
            input_left, input_right = input_str.split(',')

            # tensordot's result carries the uncontracted indices of the left
            # operand followed by those of the right operand.
            tensor_result = input_left + input_right
            for s in idx_rm:
                tensor_result = tensor_result.replace(s, "")

            # Find indices to contract over
            left_pos, right_pos = [], []
            for s in sorted(idx_rm):
                left_pos.append(input_left.find(s))
                right_pos.append(input_right.find(s))

            # Contract!
            new_view = tensordot(*tmp_operands, axes=(tuple(left_pos), tuple(right_pos)))

            # Build a new view if needed; a final c_einsum also honors any
            # requested output array and dtype/order/casting kwargs.
            if (tensor_result != results_index) or handle_out:
                if handle_out:
                    kwargs["out"] = out
                new_view = c_einsum(tensor_result + '->' + results_index, new_view, **kwargs)

        # Call einsum
        else:
            # If out was specified
            if handle_out:
                kwargs["out"] = out

            # Do the contraction
            new_view = c_einsum(einsum_str, *tmp_operands, **kwargs)

        # Append new items and dereference what we can
        operands.append(new_view)
        del tmp_operands, new_view

    if specified_out:
        return out
    else:
        return operands[0]