Searched refs:primitive_transposes (Results 1 – 8 of 8) sorted by relevance
254 return primitive_transposes[p]424 primitive_transposes: Dict[core.Primitive, Callable] = {} variable429 primitive_transposes[primitive] = partial(linear_transpose, transpose_rule)445 primitive_transposes[primitive] = partial(linear_transpose2, transpose_rule)489 primitive_transposes[prim] = partial(bilinear_transpose, lhs_rule, rhs_rule)554 primitive_transposes[core.call_p] = partial(call_transpose, call_p)585 primitive_transposes[pe.remat_call_p] = remat_transpose689 primitive_transposes[custom_lin_p] = _custom_lin_transpose739 primitive_transposes[fun_lin_p] = fun_lin_transpose
905 ad.primitive_transposes[xla_call_p] = partial(ad.call_transpose, xla_call_p)1444 ad.primitive_transposes[core.named_call_p] = partial(ad.call_transpose,
1247 ad.primitive_transposes[xla_pmap_p] = partial(ad.map_transpose, xla_pmap_p)
2401 ad.primitive_transposes[conj_p] = _conj_transpose_rule2493 ad.primitive_transposes[add_p] = _add_transpose2516 ad.primitive_transposes[sub_p] = _sub_transpose2534 ad.primitive_transposes[div_p] = _div_transpose_rule3430 ad.primitive_transposes[concatenate_p] = _concatenate_transpose_rule3819 ad.primitive_transposes[select_p] = _select_transpose_rule4096 ad.primitive_transposes[dynamic_update_slice_p] = \4343 ad.primitive_transposes[gather_p] = _gather_transpose_rule4653 ad.primitive_transposes[scatter_add_p] = _scatter_add_transpose_rule5481 ad.primitive_transposes[select_and_scatter_add_p] = \[all …]
546 ad.primitive_transposes[while_p] = _while_transpose_error1108 ad.primitive_transposes[cond_p] = _cond_transpose1876 ad.primitive_transposes[scan_p] = _scan_transpose2359 ad.primitive_transposes[linear_solve_p] = _linear_solve_transpose_rule
662 ad.primitive_transposes[triangular_solve_p] = triangular_solve_transpose_rule
679 ad.primitive_transposes[id_tap_dep_p] = _id_tap_dep_transpose_rule1042 ad.primitive_transposes[outside_call_p] = _outside_call_transpose_rule
374 ad.primitive_transposes[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_transpose