1// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package bsoncodec
8
9import (
10	"reflect"
11	"sync"
12
13	"go.mongodb.org/mongo-driver/bson/bsonrw"
14	"go.mongodb.org/mongo-driver/bson/bsontype"
15)
16
17var _ ValueEncoder = &PointerCodec{}
18var _ ValueDecoder = &PointerCodec{}
19
20// PointerCodec is the Codec used for pointers.
21type PointerCodec struct {
22	ecache map[reflect.Type]ValueEncoder
23	dcache map[reflect.Type]ValueDecoder
24	l      sync.RWMutex
25}
26
27// NewPointerCodec returns a PointerCodec that has been initialized.
28func NewPointerCodec() *PointerCodec {
29	return &PointerCodec{
30		ecache: make(map[reflect.Type]ValueEncoder),
31		dcache: make(map[reflect.Type]ValueDecoder),
32	}
33}
34
35// EncodeValue handles encoding a pointer by either encoding it to BSON Null if the pointer is nil
36// or looking up an encoder for the type of value the pointer points to.
37func (pc *PointerCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
38	if val.Kind() != reflect.Ptr {
39		if !val.IsValid() {
40			return vw.WriteNull()
41		}
42		return ValueEncoderError{Name: "PointerCodec.EncodeValue", Kinds: []reflect.Kind{reflect.Ptr}, Received: val}
43	}
44
45	if val.IsNil() {
46		return vw.WriteNull()
47	}
48
49	pc.l.RLock()
50	enc, ok := pc.ecache[val.Type()]
51	pc.l.RUnlock()
52	if ok {
53		if enc == nil {
54			return ErrNoEncoder{Type: val.Type()}
55		}
56		return enc.EncodeValue(ec, vw, val.Elem())
57	}
58
59	enc, err := ec.LookupEncoder(val.Type().Elem())
60	pc.l.Lock()
61	pc.ecache[val.Type()] = enc
62	pc.l.Unlock()
63	if err != nil {
64		return err
65	}
66
67	return enc.EncodeValue(ec, vw, val.Elem())
68}
69
70// DecodeValue handles decoding a pointer by looking up a decoder for the type it points to and
71// using that to decode. If the BSON value is Null, this method will set the pointer to nil.
72func (pc *PointerCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
73	if !val.CanSet() || val.Kind() != reflect.Ptr {
74		return ValueDecoderError{Name: "PointerCodec.DecodeValue", Kinds: []reflect.Kind{reflect.Ptr}, Received: val}
75	}
76
77	if vr.Type() == bsontype.Null {
78		val.Set(reflect.Zero(val.Type()))
79		return vr.ReadNull()
80	}
81	if vr.Type() == bsontype.Undefined {
82		val.Set(reflect.Zero(val.Type()))
83		return vr.ReadUndefined()
84	}
85
86	if val.IsNil() {
87		val.Set(reflect.New(val.Type().Elem()))
88	}
89
90	pc.l.RLock()
91	dec, ok := pc.dcache[val.Type()]
92	pc.l.RUnlock()
93	if ok {
94		if dec == nil {
95			return ErrNoDecoder{Type: val.Type()}
96		}
97		return dec.DecodeValue(dc, vr, val.Elem())
98	}
99
100	dec, err := dc.LookupDecoder(val.Type().Elem())
101	pc.l.Lock()
102	pc.dcache[val.Type()] = dec
103	pc.l.Unlock()
104	if err != nil {
105		return err
106	}
107
108	return dec.DecodeValue(dc, vr, val.Elem())
109}
110