1// Copyright (c) 2018 The btcsuite developers 2// Use of this source code is governed by an ISC 3// license that can be found in the LICENSE file. 4 5package psbt 6 7import ( 8 "bytes" 9 "encoding/binary" 10 "errors" 11 "fmt" 12 "io" 13 "sort" 14 15 "github.com/btcsuite/btcd/txscript" 16 "github.com/btcsuite/btcd/wire" 17) 18 19// WriteTxWitness is a utility function due to non-exported witness 20// serialization (writeTxWitness encodes the bitcoin protocol encoding for a 21// transaction input's witness into w). 22func WriteTxWitness(w io.Writer, wit [][]byte) error { 23 if err := wire.WriteVarInt(w, 0, uint64(len(wit))); err != nil { 24 return err 25 } 26 27 for _, item := range wit { 28 err := wire.WriteVarBytes(w, 0, item) 29 if err != nil { 30 return err 31 } 32 } 33 return nil 34} 35 36// writePKHWitness writes a witness for a p2wkh spending input 37func writePKHWitness(sig []byte, pub []byte) ([]byte, error) { 38 var ( 39 buf bytes.Buffer 40 witnessItems = [][]byte{sig, pub} 41 ) 42 43 if err := WriteTxWitness(&buf, witnessItems); err != nil { 44 return nil, err 45 } 46 47 return buf.Bytes(), nil 48} 49 50// checkIsMultisigScript is a utility function to check whether a given 51// redeemscript fits the standard multisig template used in all P2SH based 52// multisig, given a set of pubkeys for redemption. 53func checkIsMultiSigScript(pubKeys [][]byte, sigs [][]byte, 54 script []byte) bool { 55 56 // First insist that the script type is multisig. 57 if txscript.GetScriptClass(script) != txscript.MultiSigTy { 58 return false 59 } 60 61 // Inspect the script to ensure that the number of sigs and pubkeys is 62 // correct 63 _, numSigs, err := txscript.CalcMultiSigStats(script) 64 if err != nil { 65 return false 66 } 67 68 // If the number of sigs provided, doesn't match the number of required 69 // pubkeys, then we can't proceed as we're not yet final. 70 if numSigs != len(pubKeys) || numSigs != len(sigs) { 71 return false 72 } 73 74 return true 75} 76 77// extractKeyOrderFromScript is a utility function to extract an ordered list 78// of signatures, given a serialized script (redeemscript or witness script), a 79// list of pubkeys and the signatures corresponding to those pubkeys. This 80// function is used to ensure that the signatures will be embedded in the final 81// scriptSig or scriptWitness in the correct order. 82func extractKeyOrderFromScript(script []byte, expectedPubkeys [][]byte, 83 sigs [][]byte) ([][]byte, error) { 84 85 // If this isn't a proper finalized multi-sig script, then we can't 86 // proceed. 87 if !checkIsMultiSigScript(expectedPubkeys, sigs, script) { 88 return nil, ErrUnsupportedScriptType 89 } 90 91 // Arrange the pubkeys and sigs into a slice of format: 92 // * [[pub,sig], [pub,sig],..] 93 type sigWithPub struct { 94 pubKey []byte 95 sig []byte 96 } 97 var pubsSigs []sigWithPub 98 for i, pub := range expectedPubkeys { 99 pubsSigs = append(pubsSigs, sigWithPub{ 100 pubKey: pub, 101 sig: sigs[i], 102 }) 103 } 104 105 // Now that we have the set of (pubkey, sig) pairs, we'll construct a 106 // position map that we can use to swap the order in the slice above to 107 // match how things are laid out in the script. 108 type positionEntry struct { 109 index int 110 value sigWithPub 111 } 112 var positionMap []positionEntry 113 114 // For each pubkey in our pubsSigs slice, we'll now construct a proper 115 // positionMap entry, based on _where_ in the script the pubkey first 116 // appears. 117 for _, p := range pubsSigs { 118 pos := bytes.Index(script, p.pubKey) 119 if pos < 0 { 120 return nil, errors.New("script does not contain pubkeys") 121 } 122 123 positionMap = append(positionMap, positionEntry{ 124 index: pos, 125 value: p, 126 }) 127 } 128 129 // Now that we have the position map full populated, we'll use the 130 // index data to properly sort the entries in the map based on where 131 // they appear in the script. 132 sort.Slice(positionMap, func(i, j int) bool { 133 return positionMap[i].index < positionMap[j].index 134 }) 135 136 // Finally, we can simply iterate through the position map in order to 137 // extract the proper signature ordering. 138 sortedSigs := make([][]byte, 0, len(positionMap)) 139 for _, x := range positionMap { 140 sortedSigs = append(sortedSigs, x.value.sig) 141 } 142 143 return sortedSigs, nil 144} 145 146// getMultisigScriptWitness creates a full psbt serialized Witness field for 147// the transaction, given the public keys and signatures to be appended. This 148// function will only accept witnessScripts of the type M of N multisig. This 149// is used for both p2wsh and nested p2wsh multisig cases. 150func getMultisigScriptWitness(witnessScript []byte, pubKeys [][]byte, 151 sigs [][]byte) ([]byte, error) { 152 153 // First using the script as a guide, we'll properly order the sigs 154 // according to how their corresponding pubkeys appear in the 155 // witnessScript. 156 orderedSigs, err := extractKeyOrderFromScript( 157 witnessScript, pubKeys, sigs, 158 ) 159 if err != nil { 160 return nil, err 161 } 162 163 // Now that we know the proper order, we'll append each of the 164 // signatures into a new witness stack, then top it off with the 165 // witness script at the end, prepending the nil as we need the extra 166 // pop.. 167 witnessElements := make(wire.TxWitness, 0, len(sigs)+2) 168 witnessElements = append(witnessElements, nil) 169 for _, os := range orderedSigs { 170 witnessElements = append(witnessElements, os) 171 } 172 witnessElements = append(witnessElements, witnessScript) 173 174 // Now that we have the full witness stack, we'll serialize it in the 175 // expected format, and return the final bytes. 176 var buf bytes.Buffer 177 if err = WriteTxWitness(&buf, witnessElements); err != nil { 178 return nil, err 179 } 180 return buf.Bytes(), nil 181} 182 183// checkSigHashFlags compares the sighash flag byte on a signature with the 184// value expected according to any PsbtInSighashType field in this section of 185// the PSBT, and returns true if they match, false otherwise. 186// If no SighashType field exists, it is assumed to be SIGHASH_ALL. 187// 188// TODO(waxwing): sighash type not restricted to one byte in future? 189func checkSigHashFlags(sig []byte, input *PInput) bool { 190 expectedSighashType := txscript.SigHashAll 191 if input.SighashType != 0 { 192 expectedSighashType = input.SighashType 193 } 194 195 return expectedSighashType == txscript.SigHashType(sig[len(sig)-1]) 196} 197 198// serializeKVpair writes out a kv pair using a varbyte prefix for each. 199func serializeKVpair(w io.Writer, key []byte, value []byte) error { 200 if err := wire.WriteVarBytes(w, 0, key); err != nil { 201 return err 202 } 203 204 return wire.WriteVarBytes(w, 0, value) 205} 206 207// serializeKVPairWithType writes out to the passed writer a type coupled with 208// a key. 209func serializeKVPairWithType(w io.Writer, kt uint8, keydata []byte, 210 value []byte) error { 211 212 // If the key has no data, then we write a blank slice. 213 if keydata == nil { 214 keydata = []byte{} 215 } 216 217 // The final key to be written is: {type} || {keyData} 218 serializedKey := append([]byte{kt}, keydata...) 219 return serializeKVpair(w, serializedKey, value) 220} 221 222// getKey retrieves a single key - both the key type and the keydata (if 223// present) from the stream and returns the key type as an integer, or -1 if 224// the key was of zero length. This integer is is used to indicate the presence 225// of a separator byte which indicates the end of a given key-value pair list, 226// and the keydata as a byte slice or nil if none is present. 227func getKey(r io.Reader) (int, []byte, error) { 228 229 // For the key, we read the varint separately, instead of using the 230 // available ReadVarBytes, because we have a specific treatment of 0x00 231 // here: 232 count, err := wire.ReadVarInt(r, 0) 233 if err != nil { 234 return -1, nil, ErrInvalidPsbtFormat 235 } 236 if count == 0 { 237 // A separator indicates end of key-value pair list. 238 return -1, nil, nil 239 } 240 241 // Check that we don't attempt to decode a dangerously large key. 242 if count > MaxPsbtKeyLength { 243 return -1, nil, ErrInvalidKeydata 244 } 245 246 // Next, we ready out the designated number of bytes, which may include 247 // a type, key, and optional data. 248 keyTypeAndData := make([]byte, count) 249 if _, err := io.ReadFull(r, keyTypeAndData[:]); err != nil { 250 return -1, nil, err 251 } 252 253 keyType := int(string(keyTypeAndData)[0]) 254 255 // Note that the second return value will usually be empty, since most 256 // keys contain no more than the key type byte. 257 if len(keyTypeAndData) == 1 { 258 return keyType, nil, nil 259 } 260 261 // Otherwise, we return the key, along with any data that it may 262 // contain. 263 return keyType, keyTypeAndData[1:], nil 264 265} 266 267// readTxOut is a limited version of wire.ReadTxOut, because the latter is not 268// exported. 269func readTxOut(txout []byte) (*wire.TxOut, error) { 270 if len(txout) < 10 { 271 return nil, ErrInvalidPsbtFormat 272 } 273 274 valueSer := binary.LittleEndian.Uint64(txout[:8]) 275 scriptPubKey := txout[9:] 276 277 return wire.NewTxOut(int64(valueSer), scriptPubKey), nil 278} 279 280// SumUtxoInputValues tries to extract the sum of all inputs specified in the 281// UTXO fields of the PSBT. An error is returned if an input is specified that 282// does not contain any UTXO information. 283func SumUtxoInputValues(packet *Packet) (int64, error) { 284 // We take the TX ins of the unsigned TX as the truth for how many 285 // inputs there should be, as the fields in the extra data part of the 286 // PSBT can be empty. 287 if len(packet.UnsignedTx.TxIn) != len(packet.Inputs) { 288 return 0, fmt.Errorf("TX input length doesn't match PSBT " + 289 "input length") 290 } 291 292 inputSum := int64(0) 293 for idx, in := range packet.Inputs { 294 switch { 295 case in.WitnessUtxo != nil: 296 // Witness UTXOs only need to reference the TxOut. 297 inputSum += in.WitnessUtxo.Value 298 299 case in.NonWitnessUtxo != nil: 300 // Non-witness UTXOs reference to the whole transaction 301 // the UTXO resides in. 302 utxOuts := in.NonWitnessUtxo.TxOut 303 txIn := packet.UnsignedTx.TxIn[idx] 304 inputSum += utxOuts[txIn.PreviousOutPoint.Index].Value 305 306 default: 307 return 0, fmt.Errorf("input %d has no UTXO information", 308 idx) 309 } 310 } 311 return inputSum, nil 312} 313 314// TxOutsEqual returns true if two transaction outputs are equal. 315func TxOutsEqual(out1, out2 *wire.TxOut) bool { 316 if out1 == nil || out2 == nil { 317 return out1 == out2 318 } 319 return out1.Value == out2.Value && 320 bytes.Equal(out1.PkScript, out2.PkScript) 321} 322 323// VerifyOutputsEqual verifies that the two slices of transaction outputs are 324// deep equal to each other. We do the length check and manual loop to provide 325// better error messages to the user than just returning "not equal". 326func VerifyOutputsEqual(outs1, outs2 []*wire.TxOut) error { 327 if len(outs1) != len(outs2) { 328 return fmt.Errorf("number of outputs are different") 329 } 330 for idx, out := range outs1 { 331 // There is a byte slice in the output so we can't use the 332 // equality operator. 333 if !TxOutsEqual(out, outs2[idx]) { 334 return fmt.Errorf("output %d is different", idx) 335 } 336 } 337 return nil 338} 339 340// VerifyInputPrevOutpointsEqual verifies that the previous outpoints of the 341// two slices of transaction inputs are deep equal to each other. We do the 342// length check and manual loop to provide better error messages to the user 343// than just returning "not equal". 344func VerifyInputPrevOutpointsEqual(ins1, ins2 []*wire.TxIn) error { 345 if len(ins1) != len(ins2) { 346 return fmt.Errorf("number of inputs are different") 347 } 348 for idx, in := range ins1 { 349 if in.PreviousOutPoint != ins2[idx].PreviousOutPoint { 350 return fmt.Errorf("previous outpoint of input %d is "+ 351 "different", idx) 352 } 353 } 354 return nil 355} 356 357// VerifyInputOutputLen makes sure a packet is non-nil, contains a non-nil wire 358// transaction and that the wire input/output lengths match the partial input/ 359// output lengths. A caller also can specify if they expect any inputs and/or 360// outputs to be contained in the packet. 361func VerifyInputOutputLen(packet *Packet, needInputs, needOutputs bool) error { 362 if packet == nil || packet.UnsignedTx == nil { 363 return fmt.Errorf("PSBT packet cannot be nil") 364 } 365 366 if len(packet.UnsignedTx.TxIn) != len(packet.Inputs) { 367 return fmt.Errorf("invalid PSBT, wire inputs don't match " + 368 "partial inputs") 369 } 370 if len(packet.UnsignedTx.TxOut) != len(packet.Outputs) { 371 return fmt.Errorf("invalid PSBT, wire outputs don't match " + 372 "partial outputs") 373 } 374 375 if needInputs && len(packet.UnsignedTx.TxIn) == 0 { 376 return fmt.Errorf("PSBT packet must contain at least one " + 377 "input") 378 } 379 if needOutputs && len(packet.UnsignedTx.TxOut) == 0 { 380 return fmt.Errorf("PSBT packet must contain at least one " + 381 "output") 382 } 383 384 return nil 385} 386 387// NewFromSignedTx is a utility function to create a packet from an 388// already-signed transaction. Returned are: an unsigned transaction 389// serialization, a list of scriptSigs, one per input, and a list of witnesses, 390// one per input. 391func NewFromSignedTx(tx *wire.MsgTx) (*Packet, [][]byte, 392 []wire.TxWitness, error) { 393 394 scriptSigs := make([][]byte, 0, len(tx.TxIn)) 395 witnesses := make([]wire.TxWitness, 0, len(tx.TxIn)) 396 tx2 := tx.Copy() 397 398 // Blank out signature info in inputs 399 for i, tin := range tx2.TxIn { 400 tin.SignatureScript = nil 401 scriptSigs = append(scriptSigs, tx.TxIn[i].SignatureScript) 402 tin.Witness = nil 403 witnesses = append(witnesses, tx.TxIn[i].Witness) 404 } 405 406 // Outputs always contain: (value, scriptPubkey) so don't need 407 // amending. Now tx2 is tx with all signing data stripped out 408 unsignedPsbt, err := NewFromUnsignedTx(tx2) 409 if err != nil { 410 return nil, nil, nil, err 411 } 412 return unsignedPsbt, scriptSigs, witnesses, nil 413} 414