# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.


# Scope for collecting child 'Block's
use strict;
use warnings;
use AI::MXNet::Gluon::Parameter;
package AI::MXNet::Gluon::BlockScope;
use AI::MXNet::Function::Parameters;
my $_current;
use Mouse;
has '_block'      => (is => 'ro', init_arg => 'block', weak_ref => 1);
has [qw/_counter _old_scope
    _name_scope/] => (is => 'rw', init_arg => undef);

sub BUILD
{
    my $self = shift;
    $self->_counter({});
}

# Creates prefix and params for new Block.
method create($prefix, $params, $hint)
{
    my $current = $_current;
    if(not defined $current)
    {
        if(not defined $prefix)
        {
            $prefix = AI::MXNet::Symbol::NameManager->current->get(undef, $hint) . '_';
        }
        if(not defined $params)
        {
            $params = AI::MXNet::Gluon::ParameterDict->new(prefix => $prefix);
        }
        else
        {
            $params = AI::MXNet::Gluon::ParameterDict->new(prefix => $params->prefix, shared => $params);
        }
        return ($prefix, $params);
    }

    if(not defined $prefix)
    {
        my $count = $current->_counter->{ $hint } // 0;
        $prefix = sprintf('%s%d_', $hint, $count);
        $current->_counter->{$hint} = $count + 1;
    }
    if(not defined $params)
    {
        my $parent = $current->_block->params;
        $params = AI::MXNet::Gluon::ParameterDict->new(prefix => $parent->prefix.$prefix, shared => $parent->_shared);
    }
    else
    {
        $params = AI::MXNet::Gluon::ParameterDict->new(prefix => $params->prefix, shared => $params);
    }
    return ($current->_block->prefix.$prefix, $params);
}

method __enter__()
{
    return $self if $self->_block->_empty_prefix;
    $self->_old_scope($_current);
    $_current = $self;
    $self->_name_scope(AI::MXNet::Symbol::NameManager->current);
    AI::MXNet::Symbol::NameManager->set_current(AI::MXNet::Symbol::Prefix->new(prefix => $self->_block->prefix));
    return $self;
}

method __exit__()
{
    return if $self->_block->_empty_prefix;
    AI::MXNet::Symbol::NameManager->set_current($self->_name_scope);
    $self->_name_scope(undef);
    $_current = $self->_old_scope;
}

package AI::MXNet::Gluon::Block;
use AI::MXNet::Gluon::Mouse;
use Scalar::Util qw(refaddr);

=head2 NAME

    AI::MXNet::Gluon::Block - Base class for all neural network layers and models.

=head2 DESCRIPTION

    Base class for all neural network layers and models. Your models should
    subclass this class.

    AI::MXNet::Gluon::Block can be nested recursively in a tree structure. You can create and
    assign child AI::MXNet::Gluon::Block objects as regular attributes:

    use AI::MXNet::Gluon::NN qw(nn);
    use AI::MXNet qw(mx);

    package Model;
    use AI::MXNet::Gluon::Mouse;
    use AI::MXNet::Function::Parameters;
    extends 'AI::MXNet::Gluon::Block';

    sub BUILD
    {
        my $self = shift;
        $self->name_scope(sub {
            $self->dense0(nn->Dense(5, in_units=>5));
            $self->dense1(nn->Dense(5, in_units=>5));
        });
    }

    method forward($x)
    {
        return $self->dense1->($self->dense0->($x));
    }

    my $model = Model->new();
    $model->initialize(ctx=>mx->cpu(0));
    $model->(mx->nd->zeros([10, 10], ctx=>mx->cpu(0)));


    Child AI::MXNet::Gluon::Block objects assigned this way will be registered and ->collect_params
    will collect their Parameters recursively.

    Parameters
    ----------
    prefix : Str
        Prefix acts like a name space. All child blocks created in a parent block's
        name_scope will have the parent block's prefix in their names.
        Please refer to the naming tutorial
        https://mxnet.apache.org/api/python/docs/tutorials/packages/gluon/naming.html
        for more info on prefix and naming.

    params : AI::MXNet::Gluon::ParameterDict or undef
        AI::MXNet::Gluon::ParameterDict for sharing weights with the new AI::MXNet::Gluon::Block. For example,
        if you want `dense1` to share `dense0`'s weights, you can do

        $dense0 = nn->Dense(20);
        $dense1 = nn->Dense(20, params=>$dense0->collect_params());
=cut

method _flatten(
    $args
)
{
    if(blessed $args and $args->isa('AI::MXNet::NDArray'))
    {
        return ([$args], 0);
    }
    elsif(blessed $args and $args->isa('AI::MXNet::Symbol'))
    {
        my $length = @{ $args->list_outputs() };
        $length = $length > 1 ? $length : 0;
        return ([$args], $length)
    }
    my @flat;
    my @fmts;
    for my $i (@{ $args })
    {
        my ($arg, $fmt) = __PACKAGE__->_flatten($i);
        push @flat, @{ $arg };
        push @fmts, $fmt;
    }
    return (\@flat, \@fmts);
}

method _regroup(
    $args, $fmt
)
{
    my $in_symbol = (blessed $args and $args->isa('AI::MXNet::Symbol'));
    my @ret;
    if(not ref $fmt)
    {
        my $len = @{$args} - 1;
        if($fmt == 0)
        {
            @ret = ([@{$args}[1..$len]]);
            if($in_symbol)
            {
                $ret[0] = AI::MXNet::Symbol->Group($ret[0]);
            }
            return (@{$args}[0], $ret[0]);
        }
        @ret = ([@{$args}[0..$fmt-1]], [@{$args}[$fmt..$len]]);
        if($in_symbol)
        {
            @ret = map { AI::MXNet::Symbol->Group($_) } @ret;
        }
        return @ret;
    }
    for my $i (@{ $fmt })
    {
        my $res;
        ($res, $args) = __PACKAGE__->_regroup($args, $i);
        push @ret, $res;
    }
    return (\@ret, $args);
}

has _prefix => (is => 'rw', init_arg => 'prefix', isa => 'Str');
has _params => (is => 'rw', init_arg => 'params', isa => 'Maybe[AI::MXNet::Gluon::ParameterDict]');
has [qw/_name _scope _empty_prefix/] => (is => 'rw', init_arg => undef);
has [qw/_children _forward_hooks _forward_pre_hooks/]  => (is => 'rw', init_arg => undef, default => sub { Hash::Ordered->new });
has '_reg_params' => (is => 'rw', init_arg => undef, default => sub { +{} });
around BUILDARGS => \&AI::MXNet::Base::process_arguments;

sub AUTOLOAD {
    my $name = $AI::MXNet::Gluon::Block::AUTOLOAD;
    $name =~ s/.*:://;
    my $self = shift;
    AI::MXNet::Gluon::Mouse::has($name => (is => 'rw', 'init_arg' => undef, 'caller' => ref $self));
    $self->$name(@_);
}

sub BUILD
{
    my $self = shift;
    $self->_empty_prefix(defined $self->_prefix and $self->_prefix eq '');
    my ($prefix, $params) = AI::MXNet::Gluon::BlockScope->create($self->_prefix, $self->_params, $self->_alias);
    $self->_prefix($prefix);
    $self->_params($params);
    my $name = $prefix;
    $name =~ s/_$//;
    $self->_name($name);
    $self->_scope(AI::MXNet::Gluon::BlockScope->new(block => $self));
}

method _class_name()
{
    my $class = ref $self || $self;
    $class =~ s/^.+:://;
    $class;
}

method __setattr__($name, $current, $prev=)
{
    if(defined $prev)
    {
        if(
            (
                blessed $prev
                    and
                ($prev->isa('AI::MXNet::Gluon::Parameter') or $prev->isa('AI::MXNet::Gluon::Block'))
            )
            and not (blessed $current and (ref($prev) eq ref($current)))
        )
        {
            confess(
                sprintf(
                    "Changing attribute type for %s from %s to %s is not allowed.",
                    $self->name,
                    ref($prev),
                    ref($current)||'no ref'
                )
            );
        }
    }
    if(blessed $current and $current->isa('AI::MXNet::Gluon::Block'))
    {
        $self->register_child($current, $name);
    }
    elsif(blessed $current and $current->isa('AI::MXNet::Gluon::Parameter'))
    {
        if(exists $self->_reg_params->{ $name })
        {
            confess("Overriding Parameter attribute $name is not allowed. ".
                "If you want to share parameters between blocks, please set ".
                "'params' at Block construction instead."
            );
        }
        $self->_reg_params->{ $name } = $current;
    }
}

method _check_container_with_block()
{
    my $_find_unregistered_block_in_container;
    my %children = map { refaddr($_) => 1 } $self->_children->values;
    $_find_unregistered_block_in_container = sub { my ($data) = @_;
        # Find whether a nested container structure contains Blocks
        if(ref $data eq 'ARRAY')
        {
            for my $ele (@{ $data })
            {
                if($_find_unregistered_block_in_container->($ele))
                {
                    return 1
                }
            }
            return 0;
        }
        elsif(ref $data eq 'HASH')
        {
            for my $v (values %$data)
            {
                if($_find_unregistered_block_in_container->($v))
                {
                    return 1;
                }
            }
            return 0;
        }
        elsif(blessed $data and $data->isa('AI::MXNet::Gluon::Block'))
        {
            return not exists $children{ refaddr($data) };
        }
        else
        {
            return 0;
        }
    };
    my $attributes_hash = $self->attributes_hash();
    while(my ($k, $v) = each %{ $attributes_hash })
    {
        if((ref $v eq 'HASH' or ref $v eq 'ARRAY') and not $k =~ /^__/)
        {
            if($_find_unregistered_block_in_container->($v))
            {
                AI::MXNet::Logging->warning(
                    '"%s" is an unregistered container with Blocks. '.
                    'Note that Blocks inside the list, tuple or dict will not be '.
                    'registered automatically. Make sure to register them using '.
                    'register_child() or by switching to '.
                    'nn->Sequential/nn->HybridSequential instead. ',
                    $self->_class_name.'.'.$k
                );
            }
        }
    }
}

method _alias()
{
    lc $self->_class_name;
}

method attributes_hash()
{
    +{ map { $_ => $self->$_ } $self->meta->get_attribute_list };
}

use overload
    '""' => sub
    {
        my $self = shift;
        my $s = "%s(\n%s\n)";
        my @blocks;
        my %attributes_hash = %{ $self->attributes_hash };
        while(my ($k, $v) = each %attributes_hash)
        {
            if(blessed $v and $v->isa(__PACKAGE__))
            {
                push @blocks, "  ($k): ".AI::MXNet::Base::_indent("$v", 2);
            }
        }
        sprintf($s, $self->_class_name, join("\n", @blocks));
    },
    '&{}' => sub { my $self = shift; sub { $self->call(@_) } };

method prefix()
{
    $self->_prefix;
}

method name()
{
    $self->_name;
}

method class()
{
    __PACKAGE__;
}

method name_scope(CodeRef $sub)
{
    $self->_scope->__enter__;
    eval { $sub->(); };
    my $err = $@;
    $self->_scope->__exit__;
    confess($err) if $err;
}

=head2 params

        Returns this `Block`'s parameter dictionary (does not include its
        children's parameters).
=cut

method params()
{
    return $self->_params;
}

=head2 collect_params

        Returns an AI::MXNet::Gluon::ParameterDict containing this AI::MXNet::Gluon::Block's and all of its
        children's Parameters (the default), or a ParameterDict containing only the
        parameters whose names match a regular expression.

        For example, collect the parameters 'conv1_weight', 'conv1_bias', 'fc_weight'
        and 'fc_bias':

            $model->collect_params('conv1_weight|conv1_bias|fc_weight|fc_bias')

        or collect all parameters whose names end with 'weight' or 'bias':

            $model->collect_params('.*weight|.*bias')

=cut

method collect_params(Maybe[Str] $select=)
{
    $self->_check_container_with_block();
    my $ret = AI::MXNet::Gluon::ParameterDict->new(prefix => $self->_params->prefix);
    $ret->update($self->params, $select);
    for my $cld ($self->_children->values)
    {
        $ret->update($cld->collect_params($select));
    }
    return $ret;
}


method _collect_params_with_prefix(Str $prefix='')
{
    if($prefix)
    {
        $prefix .= '.';
    }
    my %ret = map { $prefix.$_ => $self->_reg_params->{ $_ } } keys %{ $self->_reg_params };
    my $iter = $self->_children->iterator;
    while(my ($name, $child) = $iter->())
    {
        %ret = (%ret, %{ $child->_collect_params_with_prefix("$prefix$name") });
    }
    return \%ret;
}

=head2 save_parameters

        Save parameters to file.

        $filename : Str
            Path to file.
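
        A minimal sketch of saving (the layer, its sizes and the file name
        are illustrative; 'nn' is assumed to be imported from
        AI::MXNet::Gluon::NN as shown in the DESCRIPTION above):

            my $net = nn->Dense(5, in_units=>10);
            $net->initialize();
            $net->save_parameters('dense.params');
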
=cut

method save_parameters(Str $filename)
{
    my $params = $self->_collect_params_with_prefix();
    my %arg_dict = map { $_ => $params->{$_}->_reduce } keys %{ $params };
    AI::MXNet::NDArray->save($filename, \%arg_dict);
}

=head2 load_parameters

        Load parameters from file.

        $filename : Str
            Path to parameter file.
        :$ctx : Context or list of Context
            Context(s) to initialize loaded parameters on.
        :$allow_missing : Bool, default 0
            Whether to silently skip loading parameters not present in the file.
        :$ignore_extra : Bool, default 0
            Whether to silently ignore parameters from the file that are not
            present in this Block.
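
        A minimal sketch of restoring the parameters saved by the
        save_parameters example above (file name and context are illustrative):

            my $net = nn->Dense(5, in_units=>10);
            $net->load_parameters('dense.params', ctx => mx->cpu(0));
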
=cut

method load_parameters(
    Str   $filename,
    AI::MXNet::Context|ArrayRef[AI::MXNet::Context] :$ctx=AI::MXNet::Context->current_ctx,
    Bool  :$allow_missing=0,
    Bool  :$ignore_extra=0
)
{
    my $loaded = AI::MXNet::NDArray->load($filename);
    my $params = $self->_collect_params_with_prefix;
    return if not keys %$loaded and not keys %$params;

    if(not grep { /\./ } keys %$loaded)
    {
        # legacy loading
        %$loaded = ();
        $self->collect_params->load(
            $filename,
            ($ctx ? (ctx   => $ctx) : ()),
            allow_missing  => $allow_missing,
            ignore_extra   => $ignore_extra,
            restore_prefix => $self->prefix
        );
        return;
    }

    if(not $allow_missing)
    {
        for my $name (keys %$params)
        {
            if(not exists $loaded->{$name})
            {
                confess(
                    "Parameter $name is missing in file $filename, which contains parameters:".
                    join(',', keys %$loaded)."\n".
                    "Set allow_missing=>1 to ignore missing parameters."
                );
            }
        }
    }
    for my $name (keys %$loaded)
    {
        if(not $ignore_extra and not exists $params->{ $name })
        {
            confess(
                "Parameter $name loaded from file $filename is not present in ParameterDict, ".
                "which contains parameters ".
                join(',', keys %$params)."\n".
                "Set ignore_extra=>1 to ignore."
            );
        }
        $params->{$name}->_load_init($loaded->{$name}, $ctx) if exists $params->{$name};
    }
}

=head2 register_child

        Registers block as a child of self. Blocks assigned to self as
        attributes will be registered automatically.
=cut

method register_child(AI::MXNet::Gluon::Block $block, Maybe[Str] $name=)
{
    $name //= $self->_children->keys;
    $self->_children->set($name, $block);
}

=head2 register_forward_pre_hook

        Registers a forward pre-hook on the block.

        The hook function is called immediately before 'forward'.
        It should not modify the input or output.

        Parameters
        ----------
        $hook : CodeRef or callable object
            The forward pre-hook function, of the form $hook->($block, $input).

        Returns
        -------
        AI::MXNet::Gluon::Utils::HookHandle
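
        For example, a pre-hook that reports each call before 'forward' runs
        (a minimal sketch; $net stands for any Block you have constructed):

            my $handle = $net->register_forward_pre_hook(sub {
                my ($block, $input) = @_;
                print "about to run ", $block->name, "\n";
            });

            $handle->detach;  # remove the hook when it is no longer needed
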
=cut

method register_forward_pre_hook($hook)
{
    my $handle = AI::MXNet::Gluon::Utils::HookHandle->new;
    $handle->attach($self->_forward_pre_hooks, $hook);
    return $handle;
}

=head2 register_forward_hook

        Registers a forward hook on the block.

        The hook function is called immediately after 'forward'.
        It should not modify the input or output.

        Parameters
        ----------
        $hook : CodeRef or callable object
            The forward hook function, of the form $hook->($block, $input, $output).

        Returns
        -------
        AI::MXNet::Gluon::Utils::HookHandle
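
        For example, a hook that reports how many outputs 'forward' produced
        (a minimal sketch; $net stands for any Block you have constructed):

            my $handle = $net->register_forward_hook(sub {
                my ($block, $input, $output) = @_;
                print $block->name, " produced ", scalar(@$output), " output(s)\n";
            });
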
=cut

method register_forward_hook($hook)
{
    my $handle = AI::MXNet::Gluon::Utils::HookHandle->new;
    $handle->attach($self->_forward_hooks, $hook);
    return $handle;
}

=head2 apply

        Applies $fn recursively to every child block as well as self.

        Parameters
        ----------
        $fn : callable
            Function to be applied to each submodule, of form `$fn->($block)`.

        Returns
        -------
        this block
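
        For example, printing the name of every block in the tree
        (a minimal sketch; $net stands for any Block you have constructed):

            $net->apply(sub {
                my ($block) = @_;
                print $block->name, "\n";
            });
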
=cut

method apply($fn)
{
    for my $cld ($self->_children->values)
    {
        $cld->apply($fn);
    }
    $fn->($self);
    return $self;
}

=head2 initialize


        Initializes AI::MXNet::Gluon::Parameters of this AI::MXNet::Gluon::Block and its children.
        Equivalent to $block->collect_params()->initialize(...)

        Parameters
        ----------
        $init : Initializer
            Global default Initializer to be used when Parameter->init is undefined.
            Otherwise, Parameter->init takes precedence.
        :$ctx : Context or array ref of Context
            Keeps a copy of Parameters on one or many context(s).
        :$verbose : Bool, default 0
            Whether to verbosely print out details on initialization.
        :$force_reinit : Bool, default 0
            Whether to force re-initialization if parameter is already initialized.
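
        A minimal sketch (the layer, its sizes and the context are illustrative;
        'nn' and 'mx' are assumed to be imported as in the DESCRIPTION above):

            my $net = nn->Dense(5, in_units=>10);
            $net->initialize(
                AI::MXNet::Initializer->Uniform(),
                ctx          => [mx->cpu(0)],
                force_reinit => 1
            );
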
=cut

method initialize(
    Initializer $init=AI::MXNet::Initializer->Uniform(),
    AI::MXNet::Context|ArrayRef[AI::MXNet::Context] :$ctx=AI::MXNet::Context->current_ctx,
    Bool :$verbose=0,
    Bool :$force_reinit=0
)
{
    $self->collect_params->initialize(init => $init, ctx => $ctx, verbose => $verbose, force_reinit => $force_reinit);
}


=head2 hybridize

        Activates or deactivates HybridBlocks recursively. Has no effect on
        non-hybrid children.

        Parameters
        ----------
        $active : Bool, default 1
            Whether to turn hybridization on or off.
        :$static_alloc : Bool, default 0
            Statically allocate memory to improve speed. Memory usage may increase.
        :$static_shape : Bool, default 0
            Optimize for invariant input shapes between iterations. Must also
            set static_alloc to 1. Change of input shapes is still allowed
            but slower.
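
        A minimal sketch (the network is illustrative; 'nn' and 'mx' are assumed
        to be imported as in the DESCRIPTION above):

            my $net = nn->HybridSequential();
            $net->name_scope(sub {
                $net->add(nn->Dense(10, in_units=>20));
            });
            $net->initialize();
            $net->hybridize(1, static_alloc => 1);
            $net->(mx->nd->ones([4, 20]));   # first call builds and caches the graph
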
=cut

method hybridize(
    Bool $active=1,
    %args
)
{
    $_->hybridize(
        $active,
        %args
    ) for $self->_children->values;
}

=head2 cast

        Cast this Block to use another data type.

        Parameters
        ----------
        dtype : Dtype
            The new data type.
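
        For example, casting a block and its parameters to half precision
        (a minimal sketch; $net stands for any Block you have constructed):

            $net->cast('float16');
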
=cut

method cast(Dtype $dtype)
{
    for my $child ($self->_children->values)
    {
        $child->cast($dtype);
    }
    $_->cast($dtype) for $self->params->values;
}

method call(@args)
{
    for my $hook ($self->_forward_pre_hooks->values)
    {
        $hook->($self, \@args);
    }
    my @out = $self->forward(@args);
    for my $hook ($self->_forward_hooks->values)
    {
        $hook->($self, \@args, \@out);
    }
    return wantarray ? @out : $out[0];
}

=head2 forward

        Overrides to implement forward computation using `NDArray`. Only
        accepts positional arguments.

        Parameters
        ----------
        @args : array of NDArray
            Input tensors.
=cut

method forward(@args)
{
    confess("Not Implemented");
}

method register(Str $container)
{
    my $sub_name = $self->_class_name;
    my $dest = $self->can('new');
    my $func = sub {
        splice @_, 0, 1, $self;
        goto $dest;
    };
    no strict 'refs';
    *{"$container\::$sub_name"} = $func;
}

=head2 summary

        Print the summary of the model's output and parameters.

        The network must have been initialized, and must not have been hybridized.

        Parameters
        ----------
        @inputs : objects
            Any inputs that the model supports. For any tensor in the input, only
            AI::MXNet::NDArray is supported.
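
        A minimal sketch (the model and the input shape are illustrative;
        'nn' and 'mx' are assumed to be imported as in the DESCRIPTION above):

            my $net = nn->Dense(5, in_units=>10);
            $net->initialize();
            $net->summary(mx->nd->ones([2, 10]));
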
=cut

method summary(@inputs)
{
    my $summary = Hash::Ordered->new;
    my %seen;
    my @hooks;
    my $stringify;
    $stringify = sub {
        my $in = shift;
        if(ref($in) eq 'ARRAY')
        {
            return '('.join(', ', map { $stringify->($_) } @$in).')';
        }
        else
        {
            return "$in";
        }
    };
    my $_get_shape_str = sub { my ($args) = @_;
        $args = $args->[0] if(ref $args eq 'ARRAY' and @$args == 1);
        my ($flat_args, $fmts) = __PACKAGE__->_flatten($args);
        my $flat_arg_shapes = [map { (blessed($_) and $_->isa('AI::MXNet::NDArray')) ? $_->shape : $_ } @$flat_args];
        my $shapes = (__PACKAGE__->_regroup($flat_arg_shapes, $fmts))[0];
        my $shape_str = $stringify->($shapes);
        $shape_str =~ s/L//g;
        return $shape_str;
    };

    my $_register_summary_hook = sub { my ($block) = @_;
        unless(not $block->isa('AI::MXNet::Gluon::HybridBlock') or not $block->_active)
        {
            confess("\"${\ $block->name }\" must not be hybridized to print summary.");
        }
        my $_summary_hook = sub { my ($block, undef, $outputs) = @_;
            my $class_name = $block->_class_name;
            my $block_idx = $summary->keys - 1;

            my $m_key = sprintf('%s-%i', $class_name, $block_idx+1);
            $summary->set($m_key, Hash::Ordered->new);
            $summary->get($m_key)->set('output_shape', $_get_shape_str->($outputs));

            my $params = 0;
            $summary->get($m_key)->set('trainable', 0);
            $summary->get($m_key)->set('shared', 0);
            for my $p (values %{ $block->_reg_params })
            {
                $params += $p->data->size;
                $summary->get($m_key)->set('trainable', $summary->get($m_key)->get('trainable') + ($p->grad_req eq 'null' ? 0 : $p->data->size));
                if(exists $seen{$p})
                {
                    $summary->get($m_key)->set('shared', $summary->get($m_key)->get('shared') + $p->data->size);
                }
                else
                {
                    $seen{$p} = 1;
                }
            }
            $summary->get($m_key)->set('n_params', $params);
        };

        if(not $block->isa('AI::MXNet::Gluon::NN::Sequential') and not $block->isa('AI::MXNet::Gluon::NN::HybridSequential'))
        {
            push @hooks, $block->register_forward_hook($_summary_hook);
        }
    };

    my $input = Hash::Ordered->new;
    $summary->set('Input', $input);
    $input->set('output_shape', $_get_shape_str->(\@inputs));
    $input->set('n_params', 0);
    $input->set('trainable', 0);
    $input->set('shared', 0);

    eval {
        $self->apply($_register_summary_hook);
        $self->(@inputs);

        my $line_format = "%20s  %42s %15s\n";
        print (('-')x80, "\n");
        printf($line_format, 'Layer (type)', 'Output Shape', 'Param #');
        print (('=')x80, "\n");
        my $total_params = 0;
        my $trainable_params = 0;
        my $shared_params = 0;
        for my $layer ($summary->keys)
        {
            printf($line_format, $layer, $summary->get($layer)->get('output_shape'), $summary->get($layer)->get('n_params'));
            $total_params += $summary->get($layer)->get('n_params');
            $trainable_params += $summary->get($layer)->get('trainable');
            $shared_params += $summary->get($layer)->get('shared');
        }
        print (('=')x80, "\n");
        print "Parameters in forward computation graph, duplicate included\n";
        print "   Total params: $total_params\n";
        print "   Non-trainable params: ", $total_params - $trainable_params, "\n";
        print "Shared params in forward computation graph: $shared_params\n";
        print "Unique parameters in model: ", $total_params - $shared_params, "\n";
        print (('-')x80, "\n");
    };
    $_->detach for @hooks;
}

__PACKAGE__->register('AI::MXNet::Gluon');

package AI::MXNet::Gluon::HybridBlock;

=head2 NAME

    AI::MXNet::Gluon::HybridBlock

=head2 DESCRIPTION

    HybridBlock supports forwarding with both Symbol and NDArray.

    Forward computation in HybridBlock must be static to work with Symbols,
    i.e. you cannot call aspdl, shape, dtype, etc. on tensors.
    Also, you cannot use branching or loop logic that depends on non-constant
    expressions like random numbers or intermediate results, since they change
    the graph structure for each iteration.

    Before activating with hybridize(), HybridBlock works just like a normal
    Block. After activation, HybridBlock will create a symbolic graph
    representing the forward computation and cache it. On subsequent forwards,
    the cached graph will be used instead of hybrid_forward.

    Refer to the Hybrid tutorial L<https://mxnet.io/tutorials/gluon/hybrid.html> to see
    the end-to-end usage.
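
    A minimal sketch of a HybridBlock subclass (the package name, layer size and
    input shape are illustrative; 'nn' and 'mx' are assumed to be imported as in
    the AI::MXNet::Gluon::Block DESCRIPTION above):

    package Net;
    use AI::MXNet::Gluon::Mouse;
    use AI::MXNet::Function::Parameters;
    extends 'AI::MXNet::Gluon::HybridBlock';

    sub BUILD
    {
        my $self = shift;
        $self->name_scope(sub {
            $self->dense0(nn->Dense(5, in_units=>5));
        });
    }

    # $F is 'AI::MXNet::NDArray' in imperative mode and 'AI::MXNet::Symbol'
    # once the block has been hybridized
    method hybrid_forward($F, $x)
    {
        return $self->dense0->($x);
    }

    package main;
    my $net = Net->new;
    $net->initialize();
    $net->hybridize();               # later calls reuse the cached graph
    $net->(mx->nd->ones([3, 5]));
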
=cut

use AI::MXNet::Gluon::Mouse;
use AI::MXNet::Base;
extends 'AI::MXNet::Gluon::Block';
has [qw/
        _cached_graph
        _cached_op
        _out_format _in_format
        _active _flags _cached_op_args
/] => (is => 'rw', init_arg => undef);

sub BUILD
{
    my $self = shift;
    $self->_active(0);
    $self->_flags([]);
    $self->_cached_graph([]);
    $self->_cached_op_args([]);
}

method __setattr__($name, $current, $prev=)
{
    $self->SUPER::__setattr__($name, $current, $prev);
    if(blessed $current and $current->isa('AI::MXNet::Gluon::HybridBlock'))
    {
        $self->_clear_cached_op();
    }
}

method register_child(AI::MXNet::Gluon::HybridBlock $block, Maybe[Str] $name=)
{
    $self->SUPER::register_child($block, $name);
    $self->_clear_cached_op();
}

method hybridize(@args)
{
    my $active;
    if(@args%2)
    {
        $active = shift(@args);
    }
    else
    {
        $active = 1;
    }
    $self->_active($active);
    @{ $self->_flags } = @args;
    $self->_clear_cached_op();
    if($self->_active and ($self->_forward_hooks or $self->_forward_pre_hooks))
    {
        AI::MXNet::Logging->warning(
            "$self is being hybridized while still having forward hook/pre-hook. ".
            "If $self is a child of HybridBlock, the hooks will not take effect."
        );
    }
    $self->SUPER::hybridize($self->_active, @args);
}

method cast(Dtype $dtype)
{
    $self->_clear_cached_op;
    $self->SUPER::cast($dtype);
}

method _infer_attrs($infer_fn, $attr, @args)
{
    my ($inputs, $out) = $self->_get_graph(@args);
    my ($args) = __PACKAGE__->_flatten([@args]);
    my %in;
    zip(sub {
        my ($i, $j) = @_;
        $in{ $i->name } = $j->$attr;
    }, $inputs, $args);
    my ($arg_attrs, $aux_attrs);
    ($arg_attrs, undef, $aux_attrs) = $out->$infer_fn(%in);
    if(not defined $arg_attrs)
    {
        confess($@);
    }
    my %sdict;
    zip(sub {
        my ($i, $j) = @_;
        $sdict{ $i } = $j;
    }, $out->list_arguments, $arg_attrs);
    zip(sub {
        my ($i, $j) = @_;
        $sdict{ $i } = $j;
    }, $out->list_auxiliary_states, $aux_attrs);

    for my $i ($self->collect_params->values)
    {
        $i->$attr($sdict{ $i->name });
    }
}

=head2 infer_shape

        Infers shape of Parameters from inputs.
=cut

method infer_shape(@args)
{
    $self->_infer_attrs('infer_shape', 'shape', @args);
}

method infer_type(@args)
{
    $self->_infer_attrs('infer_type', 'dtype', @args);
}

method _get_graph(@args)
{
    if(not @{ $self->_cached_graph })
    {
        my $args = [@args];
        my ($in_format, $out_format);
        ($args, $in_format) = __PACKAGE__->_flatten($args);
        $self->_in_format($in_format);
        my @inputs;
        if(@args > 1)
        {
            @inputs = map { AI::MXNet::Symbol->var("data_$_") } 0 .. @$args-1;
        }
        else
        {
            @inputs = (AI::MXNet::Symbol->var("data"))
        }
        my ($grouped_inputs) = __PACKAGE__->_regroup(\@inputs, $self->_in_format);
        my %params = map { $_ => $self->_reg_params->{$_}->var } keys %{ $self->_reg_params };
        my @out;
        $self->name_scope(sub {
            @out = $self->hybrid_forward('AI::MXNet::Symbol', @{ $grouped_inputs }, %params);
        });
        my $out = @out > 1 ? [@out] : $out[0];
        ($out, $out_format) = __PACKAGE__->_flatten($out);
        $self->_out_format($out_format);
        @{ $self->_cached_graph } = (\@inputs, AI::MXNet::Symbol->Group($out));
    }
    return @{ $self->_cached_graph };
}

method _build_cache(@args)
{
    my ($data, $out) = $self->_get_graph(@args);
    my $i = 0;
    my %data_names = map { $_->name => $i++ } @{ $data };
    my $params = $self->collect_params;
    my $input_names = $out->list_inputs;
    my %param_names = map { $_ => 1 } $params->keys;
    my %expected_names = map { $_ => 1 } @{ $input_names };
    for my $name (keys %expected_names)
    {
        assert(
            (exists $param_names{ $name } or exists $data_names{ $name }),
            "Unknown input to HybridBlock: $name"
        );
    }
    my $unused = join(', ', map { "$data_names{$_}-th" } grep { !exists $expected_names{ $_ } } keys %data_names);
    AI::MXNet::Logging->warn(
        "The $unused input to HybridBlock is not used by any ".
        "computation. Is this intended?"
    ) if $unused;
    $unused = join(', ', grep { !exists $expected_names{ $_ } } keys %param_names);
    AI::MXNet::Logging->warn(
        "Parameter $unused is not used by any computation. ".
        "Is this intended?"
    ) if $unused;

    my @data_indices;
    my @param_indices;
    $self->_cached_op_args([]);
    enumerate(sub {
        my ($i, $name) = @_;
        if(exists $data_names{ $name })
        {
            push @data_indices, $i;
            push @{ $self->_cached_op_args }, [1, $data_names{$name}];
        }
        else
        {
            push @param_indices, $i;
            push @{ $self->_cached_op_args }, [0, $params->params->get($name)];
        }
    }, $input_names);
    my %flags = (
        data_indices  => \@data_indices,
        param_indices => \@param_indices,
        @{ $self->_flags }
    );
    $self->_cached_op(AI::MXNet::CachedOp->new($out, \%flags));
}

method _deferred_infer_shape(@args)
{
    eval {
        $self->infer_shape(@args)
    };
    if($@)
    {
        confess(
            "Deferred initialization failed because shape".
            " cannot be inferred. $@"
        );
    }
}

method _clear_cached_op()
{
    $self->_cached_graph([]);
    $self->_cached_op(undef);
}

use Data::Dumper;
method _call_cached_op(@args)
{
    if(not defined $self->_cached_op)
    {
        $self->_build_cache(@args);
    }
    my $args = [@args];
    my $fmt;
    ($args, $fmt) = __PACKAGE__->_flatten($args);
    assert((Dumper($fmt) eq Dumper($self->_in_format)), "Invalid input format");
    my @cargs;
    eval {
        @cargs = map { (not $_->[0]) ? $_->[1]->data() : $args->[$_->[1]] } @{ $self->_cached_op_args };
    };
    if($@)
    {
        if($@ =~ /DeferredInitializationError/)
        {
            $self->_deferred_infer_shape(@$args);
            @cargs = ();
            map {
                if($_->[0])
                {
                    push @cargs, $args->[$_->[1]];
                }
                else
                {
                    $_->[1]->_finish_deferred_init();
                    push @cargs, $_->[1]->data;
                }
            } @{ $self->_cached_op_args };
        }
        else
        {
            confess($@);
        }
    }
    my $out = $self->_cached_op->(@cargs);
    if(blessed $out and $out->isa('AI::MXNet::NDArray'))
    {
        $out = [$out];
    }
    my $ret = (__PACKAGE__->_regroup($out, $self->_out_format))[0];
    if(ref($ret) eq 'ARRAY' and wantarray)
    {
        return @$ret;
    }
    else
    {
        return $ret;
    }
}

=head2 forward

        Defines the forward computation. Arguments can be either
        NDArray or Symbol.
=cut

method forward($x, @args)
{
    if(blessed $x and $x->isa('AI::MXNet::NDArray'))
    {
        my @out;
        my $out;
        my $ctx = $x->context;
        my $current_ctx = AI::MXNet::Context->current_ctx;
        AI::MXNet::Context->set_current($ctx);
        if($self->_active)
        {
            if(wantarray)
            {
                my @out = $self->_call_cached_op($x, @args);
                AI::MXNet::Context->set_current($current_ctx);
                return @out;
            }
            else
            {
                my $out = $self->_call_cached_op($x, @args);
                AI::MXNet::Context->set_current($current_ctx);
                return $out;
            }
        }
        my %params;
        eval {
            %params = map { $_ => $self->_reg_params->{ $_ }->data($ctx) } keys %{ $self->_reg_params };
        };
        if($@)
        {
            if($@ =~ /DeferredInitializationError/)
            {
                $self->_deferred_infer_shape($x, @args);
                $_->_finish_deferred_init for $self->params->values;
                %params = map { $_ => $self->_reg_params->{ $_ }->data($ctx) } keys %{ $self->_reg_params };
            }
            else
            {
                confess($@);
            }
        }
        @out = $self->hybrid_forward('AI::MXNet::NDArray', $x, @args, %params);
        AI::MXNet::Context->set_current($current_ctx);
        return wantarray ? @out : $out[0];
    }
    assert(
        (blessed $x and $x->isa('AI::MXNet::Symbol')),
        "HybridBlock requires the first argument to forward be either ".
        "Symbol or NDArray, but got [".ref($x)."]"
    );
    my %params = map { $_ => $self->_reg_params->{ $_ }->var } keys %{ $self->_reg_params };
    my @ret;
    $self->name_scope(sub {
        @ret = $self->hybrid_forward('AI::MXNet::Symbol', $x, @args, %params);
    });
    return wantarray ? @ret : $ret[0];
}

=head2 hybrid_forward

        Override to construct the symbolic graph for this Block.

        Parameters
        ----------
        $x : Symbol or NDArray
            The first input tensor.
        @args : array of Symbol or array of NDArray
            Additional input tensors.
=cut

method hybrid_forward($F, $x, @args)
{
    confess("NotImplementedError");
}

=head2 export

        Export HybridBlock to json format that can be loaded by AI::MXNet::Module
        or the C++ interface.

        When there is only one input, it will be named 'data'. When there are
        multiple inputs, they will be named 'data_0', 'data_1', etc.

        Parameters
        ----------
        $path : Str
            Path to save model. Two files, 'path-symbol.json' and 'path-xxxx.params',
            will be created, where xxxx is the 4-digit epoch number.
        :$epoch=0 : Int
            Epoch number of saved model.
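
        A minimal sketch (network, input shape and path are illustrative; 'nn'
        and 'mx' are assumed to be imported as in the Block DESCRIPTION above):

            my $net = nn->HybridSequential();
            $net->name_scope(sub {
                $net->add(nn->Dense(2, in_units=>3));
            });
            $net->initialize();
            $net->hybridize();
            $net->(mx->nd->ones([1, 3]));          # run forward once to build the cached graph
            $net->export('my-model', epoch => 0);  # writes my-model-symbol.json and my-model-0000.params
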
=cut

method export(Str $path, :$epoch=0)
{
    if(not @{ $self->_cached_graph })
    {
        confess(
            "Please first call \$block->hybridize() and then run forward with ".
            "this block at least once before calling export."
        );
    }
    my $sym = $self->_cached_graph->[1];
    $sym->save("$path-symbol.json");

    my %arg_names = map { $_ => 1 } @{ $sym->list_arguments };
    my %aux_names = map { $_ => 1 } @{ $sym->list_auxiliary_states };
    my %arg_dict;
    my $params = $self->collect_params;
    for my $name ($params->keys)
    {
        my $param = $params->get($name);
        if(exists $arg_names{ $name })
        {
            $arg_dict{ "arg:$name" } = $param->_reduce;
        }
        else
        {
            assert(exists $aux_names{ $name });
            $arg_dict{ "aux:$name" } = $param->_reduce;
        }
    }
    AI::MXNet::NDArray->save(sprintf('%s-%04d.params', $path, $epoch), \%arg_dict);
}

__PACKAGE__->register('AI::MXNet::Gluon');

package AI::MXNet::Gluon::SymbolBlock;
use AI::MXNet::Gluon::Mouse;
use AI::MXNet::Base;
extends 'AI::MXNet::Gluon::HybridBlock';

=head1 NAME

    AI::MXNet::Gluon::SymbolBlock - Construct block from symbol.
=cut

=head1 DESCRIPTION

    Construct block from symbol. This is useful for using pre-trained models
    as feature extractors. For example, you may want to extract the output
    of the fc2 layer in AlexNet.

    Parameters
    ----------
    outputs : Symbol or list of Symbol
        The desired output for SymbolBlock.
    inputs : Symbol or list of Symbol
        The variables in the output's arguments that should be used as inputs.
    params : ParameterDict
        Parameter dictionary for arguments and auxiliary states of outputs
        that are not inputs.

    Examples
    --------
    >>> # To extract the feature from fc1 and fc2 layers of AlexNet
    >>> $alexnet = gluon->model_zoo->vision->alexnet(pretrained=>1, ctx=>mx->cpu(),
                                                 prefix=>'model_');
    >>> $inputs = mx->sym->var('data');
    >>> $out = $alexnet->($inputs);
    >>> $internals = $out->get_internals()
    >>> print($internals->list_outputs())
    ['data', ..., 'model_dense0_relu_fwd_output', ..., 'model_dense1_relu_fwd_output', ...]
    >>> $outputs = [$internals->slice('model_dense0_relu_fwd_output'),
                   $internals->slice('model_dense1_relu_fwd_output')];
    >>> # Create SymbolBlock that shares parameters with alexnet
    >>> $feat_model = gluon->SymbolBlock($outputs, $inputs, params=>$alexnet->collect_params());
    >>> $x = mx->nd->random_normal(shape=>[16, 3, 224, 224]);
    >>> print($feat_model->($x));
=cut

has [qw/outputs inputs/] => (is => 'rw', isa => 'AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol]');
method python_constructor_arguments() { [qw/outputs inputs/] }

sub BUILD
{
    my ($self, $orig_params) = @_;
    return unless defined $self->outputs and defined $self->inputs;
    $self->_prefix('');
    $self->_params(AI::MXNet::Gluon::ParameterDict->new(prefix => '', shared => $orig_params->{params}));
    if(blessed $self->inputs and @{ $self->inputs->list_outputs } == 1)
    {
        $self->inputs([$self->inputs]);
    }
    if(not blessed $self->outputs and @{ $self->outputs } == 1)
    {
        $self->outputs($self->outputs->[0]);
    }
    my ($syms, $in_format) = __PACKAGE__->_flatten($self->inputs);
    my ($out, $out_format) = __PACKAGE__->_flatten($self->outputs);
    $self->_in_format($in_format);
    $self->_out_format($out_format);
    $out = AI::MXNet::Symbol->Group($out);

    my %input_names;
    for my $i (@{ $syms })
    {
        assert(
            (@{ $i->get_internals->list_outputs() } == 1),
            "Input symbols must be variable, but $i is an output of operators"
        );
        $input_names{ $i->name } = 1;
    }

    # check if any symbol is row_sparse
    my $row_sparse_storage = STORAGE_TYPE_STR_TO_ID->{row_sparse};
    for my $i (@{ $out })
    {
        for my $j (@{ $i->get_internals })
        {
            assert(
                (not defined $j->attr("__storage_type__") or $j->attr("__storage_type__") ne $row_sparse_storage),
                "SymbolBlock doesn't support Parameter ${\ $j->name } because its storage ".
                "type is 'row_sparse'."
            );
        }
    }

    my $arg_params = $out->list_arguments;
    my $aux_params = $out->list_auxiliary_states;
    my ($arg_types, $aux_types) = _infer_param_types($syms, $out, $arg_params, $aux_params);

    for(enumerate($arg_params))
    {
        my ($i, $arg) = @$_;
        if(not exists $input_names{ $arg })
        {
            $self->params->get($arg, allow_deferred_init => 1, dtype => $arg_types->[$i]);
        }
    }

    for(enumerate($aux_params))
    {
        my ($i, $arg) = @$_;
        if(not exists $input_names{ $arg })
        {
            $self->params->get($arg, grad_req => 'null', allow_deferred_init => 1, dtype => $aux_types->[$i]);
        }
    }

    $self->_cached_graph([$syms, $out]);
    my $prefix = _common_prefix($self->_params->keys);
    my %params = $self->_params->items;
    while(my ($key, $val) = each %params)
    {
        $key =~ s/^$prefix//;
        $self->_reg_params->{ $key } = $val;
    }
    $self->_prefix($prefix);
}


func _infer_param_types($in_params, $out_params, $arg_params, $aux_params, $default_dtype='float32')
{
    # Utility function that helps in inferring the DType of arg and aux params
    # from the given input params.
    # Parameters
    # ----------
    # in_params: array ref of AI::MXNet::Symbol objects
    #     List of input symbol variables.
    # out_params: AI::MXNet::Symbol
    #     Output symbol variable.
    # arg_params: array ref of Str
    #     List of names of argument parameters.
    # aux_params: array ref of Str
    #     List of names of auxiliary parameters.
    # default_dtype: Dtype, default 'float32'
    #     Default data type for arg_params and aux_params, if unable to infer the type.
    # Returns
    # -------
    # arg_types: array ref of Dtype
    #     List of arg_params types. Order is the same as arg_params.
    #     Defaults to 'float32' if unable to infer the type.
    # aux_types: array ref of Dtype
    #     List of aux_params types. Order is the same as aux_params.
    #     Defaults to 'float32' if unable to infer the type.

    my $arg_types;
    my $aux_types;
    # Get input symbol details. This will be used to infer types of
    # other parameters.
    my @input_sym_names = map { $_->name } @{ $in_params };
    # Try to infer input types. If not successful, we will set default dtype.
    # If successful, we will try to infer other params in the graph.
    my @input_sym_arg_types;
    my $can_infer_input_type = 1;
    for my $in_param (@{ $in_params })
    {
        my $input_sym_arg_type = ($in_param->infer_type)[0];
        if(not $input_sym_arg_type or @$input_sym_arg_type < 1)
        {
            $can_infer_input_type = 0;
            last;
        }
        else
        {
            push @input_sym_arg_types, $input_sym_arg_type->[0];
        }
    }
    # Try to infer types of other parameters.
    if($can_infer_input_type)
    {
        my %params = map { $_->[0] => $_->[1] } zip(\@input_sym_names, \@input_sym_arg_types);
        ($arg_types, undef, $aux_types) = $out_params->infer_type(%params);
        if(not defined $arg_types or @$arg_types != @$arg_params)
        {
            $arg_types = [($default_dtype)x@$arg_params];
        }
        if(not defined $aux_types or @$aux_types != @$aux_params)
        {
            $aux_types = [($default_dtype)x@$aux_params];
        }
    }
    return ($arg_types, $aux_types);
}

func _common_prefix(@names)
{
    if(not @names)
    {
        return ''
    }
    my $prefix = $names[0];
    for my $name (@names)
    {
        my $i = 0;
        while($i < length($prefix) and $i < length($name) and substr($prefix, $i, 1) eq substr($name, $i, 1))
        {
            $i++;
        }
        $prefix = substr($prefix, 0, $i);
    }
    return $prefix;
}

method forward($x, @args)
{
    if(blessed $x and $x->isa('AI::MXNet::NDArray'))
    {
        my @out;
        my $out;
        my $ctx = $x->context;
        my $current_ctx = AI::MXNet::Context->current_ctx;
        AI::MXNet::Context->set_current($ctx);
        if(wantarray)
        {
            my @out = $self->_call_cached_op($x, @args);
            AI::MXNet::Context->set_current($current_ctx);
            return @out;
        }
        else
        {
            my $out = $self->_call_cached_op($x, @args);
            AI::MXNet::Context->set_current($current_ctx);
            return $out;
        }
    }
    assert(
        (blessed $x and $x->isa('AI::MXNet::Symbol')),
        "HybridBlock requires the first argument to forward be either ".
        "Symbol or NDArray, but got [".ref($x)."]"
    );
    my $args = \@args;
    my $in_fmt;
    ($args, $in_fmt) = __PACKAGE__->_flatten([$x, @$args]);
    assert((Data::Dumper::Dumper($in_fmt) eq Data::Dumper::Dumper($self->_in_format)), "Invalid input format");
    my $ret = $self->_cached_graph->[1]->deepcopy;
    my %in;
    for(zip($self->_cached_graph->[0], $args)) {
        my ($k, $v) = @$_;
        $in{$k->name} = $v;
    }
    $ret->_compose(%in);
    $ret = (__PACKAGE__->_regroup($ret, $self->_out_format))[0];
    if(ref($ret) eq 'ARRAY' and wantarray)
    {
        return @$ret;
    }
    else
    {
        return $ret;
    }
}

method _clear_cached_op()
{
    my $tmp = $self->_cached_graph;
    $self->SUPER::_clear_cached_op;
    $self->_cached_graph($tmp);
}

method hybrid_forward(@args)
{
    confess('NotImplementedError');
}

=head2 imports

        Import a model previously saved by HybridBlock->export or
        Module->save_checkpoint as a SymbolBlock for use in Gluon.

        Parameters
        ----------
        $symbol_file : Str
            Path to symbol file.
        $input_names : Str|ArrayRef[Str]
            List of input variable names.
        $param_file : Str, optional
            Path to parameter file.
        $ctx : Context, default undef
            The context to initialize the SymbolBlock on.

        Returns
        -------
        SymbolBlock
            SymbolBlock loaded from symbol and parameter files.
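
        For example, loading the files written by the export example above
        (a minimal sketch; file names are illustrative):

            my $net = AI::MXNet::Gluon::SymbolBlock->imports(
                'my-model-symbol.json', ['data'], 'my-model-0000.params'
            );
            my $out = $net->(mx->nd->ones([1, 3]));
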
=cut

method imports(Str $symbol_file, Str|ArrayRef[Str] $input_names, Maybe[Str] $param_file=, Maybe[AI::MXNet::Context] $ctx=)
{
    my $sym = AI::MXNet::Symbol->load($symbol_file);
    $input_names = [$input_names] unless ref $input_names;
    my @inputs = map { AI::MXNet::Symbol->var($_) } @{ $input_names };
    my $ret = __PACKAGE__->new($sym, \@inputs);
    if(defined $param_file)
    {
        $ret->load_parameters($param_file, (defined $ctx ? (ctx=>$ctx) : ()));
    }
    return $ret;
}

__PACKAGE__->register('AI::MXNet::Gluon');

1;
